/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.linalg.jcublas.ops.executioner;


import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.nd4j.common.base.Preconditions;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.tad.DeviceTADManager;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.INDArrayStatistics;
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.TadPack;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.api.shape.options.ArrayType;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.bindings.Nd4jCuda;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaUtf8Buffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.nativeblas.*;

import java.util.*;

import static org.bytedeco.cuda.global.cudart.*;


/**
 * JCuda executioner.
 * <p/>
 * Runs ops directly on the GPU.
 *
 * If a requested op is not available in the GPU context, {@link DefaultOpExecutioner} is used as a fallback, with arrays/buffers updated afterwards.
 *
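 * <p>A minimal usage sketch (assuming an initialized CUDA backend); ops built through
 * the {@code Nd4j} API are routed to this executioner:</p>
 * <pre>{@code
 * INDArray x = Nd4j.rand(DataType.FLOAT, 4, 4);
 * INDArray rowSums = x.sum(1);   // dispatched via exec(ReduceOp)
 * INDArray plusOne = x.add(1.0); // dispatched via exec(Op) -> invoke(ScalarOp, OpContext)
 * }</pre>
 *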
 * @author Adam Gibson
 * @author raver119@gmail.com
 */
@Slf4j
public class CudaExecutioner extends DefaultOpExecutioner {

    protected static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();


    @Getter
    protected static TADManager tadManager = new DeviceTADManager();
    protected ThreadLocal<PointerPointer> extraz = new ThreadLocal<>();
    protected volatile transient Properties properties;

    protected ThreadLocal<String> lastOp = new ThreadLocal<>();

    protected Map<String, CustomOpDescriptor> customOps = null;

    protected AtomicBoolean experimentalMode = new AtomicBoolean(false);

    public CudaExecutioner() {
        experimentalMode.set(nativeOps.isExperimentalEnabled());
    }

    public NativeOps getNativeOps() {
        return nativeOps;
    }

    @Override
    public String getLastOp() {
        return lastOp.get();
    }

    @Override
    public INDArray exec(BroadcastOp op) {
        long st = profilingConfigurableHookIn(op);

        checkForCompression(op);

        val dimension = op.dimensions().toLongVector();

        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));

        val context = AtomicAllocator.getInstance().getDeviceContext();

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            lastOp.set(op.opName());

        Pointer hostYShapeInfo =
                op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        Pointer hostZShapeInfo =
                op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());

        val x = op.x() == null ? null : op.x().data().opaqueBuffer();
        val y = op.y() == null ? null : op.y().data().opaqueBuffer();
        val z = op.z() == null ? null : op.z().data().opaqueBuffer();

        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);

        Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);
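        // TAD ("tensor along dimension") buffers: first = shape info describing each
        // sub-tensor, second = the offset of each sub-tensor within the full array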

        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);

        DataBuffer offsets = tadBuffers.getSecond();
        Pointer devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);

        Pointer devTadShapeInfoZ = null;
        Pointer devTadOffsetsZ = null;

        // second TAD set, for the output (z) array
        Pair<DataBuffer, DataBuffer> tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), dimension);

        devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getFirst(), context);
        devTadOffsetsZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getSecond(), context);
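
        // fixed "extra pointers" layout expected by the native side (indices match the
        // numbered version in invoke(BroadcastOp, OpContext)): [0] host x shape info,
        // [1] stream, [2] device id, [3..6] context buffers, [7] host y shape info,
        // [8] host z shape info, [9] host TAD shape info, [10] dev TAD shape info,
        // [11] dev TAD offsets, [12] dev z TAD shape info, [13] dev z TAD offsets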

        PointerPointer xShapeInfoHostPointer = extraz.get().put(
                AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(),
                context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(),
                hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets,
                devTadShapeInfoZ, devTadOffsetsZ);


        switch (op.getOpType()) {
            case BROADCAST:
                nativeOps.execBroadcast(xShapeInfoHostPointer, op.opNum(),
                        x, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), (LongPointer) xShapeInfo,
                        y, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context),
                        z, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
                        op.dimensions().castTo(DataType.LONG).data().opaqueBuffer(), (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.dimensions().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.dimensions().shapeInfoDataBuffer(), context));
                break;
            case BROADCAST_BOOL:
                nativeOps.execBroadcastBool(xShapeInfoHostPointer, op.opNum(),
                        x, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.x().shapeInfoDataBuffer()), (LongPointer) xShapeInfo,
                        y, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.y().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(),context),
                        z, (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.z().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
                        null,
                        op.dimensions().castTo(DataType.LONG).data().opaqueBuffer(), (LongPointer) AtomicAllocator.getInstance().getHostPointer(op.dimensions().shapeInfoDataBuffer()), (LongPointer) AtomicAllocator.getInstance().getPointer(op.dimensions().shapeInfoDataBuffer(), context));
                break;
            default:
                throw new UnsupportedOperationException("Unknown op type: " + op.getOpType());
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        profilingConfigurableHookOut(op, null, st);

        return op.z();
    }

    /**
     * Executes the given reduction along the specified dimensions. Assumes the
     * result array (z) has already been set up by the caller.
     *
     * @param op        the reduce op to execute
     * @param dimension the dimensions to reduce along
     * @return the result array ({@code op.z()})
     */
    protected INDArray naiveExec(ReduceOp op, long... dimension) {
        long st = profilingConfigurableHookIn(op);

        if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()) {
            //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y]
            //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions"
            if(op.z() != null){
                Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." +
                        " Got: x=%ndShape, z=%ndShape", op.x(), op.z());
                op.z().assign(op.x());
                return op.z();
            } else {
                op.setZ(op.x().dup());
                return op.z();
            }
        }

        INDArray ret = op.z();

        checkForCompression(op);
        op.validateDataTypes(null);

        for (int i = 0; i < dimension.length; i++)
            if (dimension[i] >= op.x().rank() && dimension[i] != Integer.MAX_VALUE)
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension)
                        + " contains an element higher than the rank of op.X: [" + op.x().rank() + "]");

        val context = AtomicAllocator.getInstance().getDeviceContext();

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            lastOp.set(op.opName());

        val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());

        Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);

        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);

        DataBuffer offsets = tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);

        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);

        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));

        PointerPointer xShapeInfoHostPointer = extraz.get().put(
                AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()),
                context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(),
                context.getBufferAllocation(),
                context.getBufferReduction(),
                context.getBufferScalar(),
                context.getBufferSpecial(),
                hostYShapeInfo,
                hostZShapeInfo,
                hostTadShapeInfo,
                devTadShapeInfo,
                devTadOffsets);

        Pointer yDevTadOffsets = null;
        Pointer yDevTadShapeInfo = null;

        if (op.y() != null) {
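            // reduce3 (x vs y): either pairwise TAD-by-TAD, or each x TAD against the whole of y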
            if (dimension.length == 0 || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) || op.x().tensorAlongDimension(0, dimension).length() != op.y().length()) {
                if (!op.isComplexAccumulation() && op.x().length() != op.y().length())
                    throw new ND4JIllegalStateException("Op.X [" + op.x().length() + "] and Op.Y [" + op.y().length() + "] lengths should match");

                if (!op.z().isScalar()) {
                    Pair<DataBuffer, DataBuffer> yTadBuffers = tadManager.getTADOnlyShapeInfo(op.y(), dimension);

                    yDevTadShapeInfo = AtomicAllocator.getInstance().getPointer(yTadBuffers.getFirst(), context);

                    DataBuffer yOffsets = yTadBuffers.getSecond();
                    yDevTadOffsets = yOffsets == null ? null : AtomicAllocator.getInstance().getPointer(yOffsets, context);

                    xShapeInfoHostPointer.put(12, yDevTadShapeInfo);
                    xShapeInfoHostPointer.put(13, yDevTadOffsets);
                }
            } else {
                // TAD vs full array code branch
                val fakeOffsets = Nd4j.getConstantHandler().getConstantBuffer(new int[] {0, 0}, DataType.LONG);
                yDevTadOffsets = fakeOffsets == null ? null : AtomicAllocator.getInstance().getPointer(fakeOffsets, context);

                yDevTadShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);

                xShapeInfoHostPointer.put(12, AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context));
                xShapeInfoHostPointer.put(13, null);
            }
        }

        DataType argsType;
        switch (op.getOpType()) {
            case REDUCE_LONG:
            case REDUCE_BOOL:
                argsType = op.x().dataType();
                break;
            default:
                argsType = op.z().dataType();
        }

        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(argsType), context) : null;
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);

        val x = op.x() == null ? null : op.x().data().opaqueBuffer();
        val y = op.y() == null ? null : op.y().data().opaqueBuffer();
        val z = op.z() == null ? null : op.z().data().opaqueBuffer();

        if (op instanceof Variance) {
            if (ret.isScalar()) {
                nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(),
                        x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        extraArgs,
                        z, (LongPointer) hostZShapeInfo,
                        (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()),
                        ((Variance) op).isBiasCorrected());
            } else {
                nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(),
                        x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        extraArgs,
                        z, (LongPointer) hostZShapeInfo,
                        (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
                        op.dimensions().castTo(DataType.LONG).data().opaqueBuffer(),
                        (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null,
                        ((Variance) op).isBiasCorrected(),
                        (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets);
            }
        } else if (op.y() != null) {
            if (ret.isScalar()) {
                nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(),
                        x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        extraArgs,
                        y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context),
                        z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context));
            } else {
                nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(),
                        x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        extraArgs,
                        y, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context),
                        z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
                        ((BaseCudaDataBuffer) op.dimensions().castTo(DataType.LONG).data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null,
                        (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) yDevTadShapeInfo, (LongPointer) yDevTadOffsets);
            }
        } else {
            if (ret.isScalar()) {
                switch (op.getOpType()) {
                    case REDUCE_FLOAT:
                        nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(),
                                x, (LongPointer) hostXShapeInfo,(LongPointer) xShapeInfo,
                                extraArgs,
                                z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()));
                        break;
                    case REDUCE_BOOL:
                        nativeOps.execReduceBool(xShapeInfoHostPointer, op.opNum(),
                                x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
                                z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()));
                        break;
                    case REDUCE_LONG:
                        nativeOps.execReduceLong(xShapeInfoHostPointer, op.opNum(),
                                x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
                                z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()));
                        break;
                    case REDUCE_SAME:
                        nativeOps.execReduceSame(xShapeInfoHostPointer, op.opNum(),
                                x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
                                z, (LongPointer) hostZShapeInfo,(LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer()));
                        break;
                    default:
                        throw new UnsupportedOperationException();
                }
            } else {
                switch (op.getOpType()) {
                    case REDUCE_FLOAT:
                        nativeOps.execReduceFloat2(xShapeInfoHostPointer, op.opNum(),
                                x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
                                z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
                                ((BaseCudaDataBuffer) op.dimensions().castTo(DataType.LONG).data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                        break;
                    case REDUCE_BOOL:
                        nativeOps.execReduceBool2(xShapeInfoHostPointer, op.opNum(),
                                x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
                                z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
                                ((BaseCudaDataBuffer) op.dimensions().castTo(DataType.LONG).data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                        break;
                    case REDUCE_SAME:
                        nativeOps.execReduceSame2(xShapeInfoHostPointer, op.opNum(),
                                x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
                                z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
                                ((BaseCudaDataBuffer) op.dimensions().castTo(DataType.LONG).data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                        break;
                    case REDUCE_LONG:
                        nativeOps.execReduceLong2(xShapeInfoHostPointer, op.opNum(),
                                x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
                                z, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context),
                                ((BaseCudaDataBuffer) op.dimensions().castTo(DataType.LONG).data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                        break;
                    default:
                        throw new UnsupportedOperationException();
                }
            }
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        profilingConfigurableHookOut(op, null, st);

        return op.z();
    }

    @Override
    public INDArray exec(Variance op) {
        return exec((ReduceOp) op);
    }

    @Override
    public INDArray exec(ReduceOp op) {
        checkForCompression(op);

        if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()){
            //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y]
            //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions"
            if(op.z() != null){
                Preconditions.checkState(op.x().equalShapes(op.z()), "For empty reductions, result (z) array must have same shape as x shape." +
                        " Got: x=%ndShape, z=%ndShape", op.x(), op.z());
                op.z().assign(op.x());
                return op.z();
            } else {
                op.setZ(op.x().dup());
                return op.z();
            }
        }

        val dimension = op.dimensions().toLongVector();

        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));

        val wholeDims = Shape.wholeArrayDimension(dimension) || op.x().rank() == dimension.length || dimension.length == 0;
        val retShape = Shape.reductionShape(op.y() == null ? op.x() : op.x().length() > op.y().length() ? op.x() : op.y(), dimension, true, op.isKeepDims());

        if (op.x().isVector() && op.x().length() == ArrayUtil.prod(retShape) && ArrayUtil.prodLong(retShape) > 1 && op.y() == null)
            return op.noOp();

        val dtype = op.resultType();
        INDArray ret = null;
        if (op.z() == null || op.z() == op.x()) {
            if (op.isComplexAccumulation()) {
                val xT = op.x().tensorsAlongDimension(dimension);
                val yT = op.y().tensorsAlongDimension(dimension);

                // uninitialized allocation is fine here: every cell of the [xT, yT] result is written by the op
                ret = Nd4j.createUninitialized(dtype, new long[] {xT, yT});
            } else {
                if (op.y() != null) {
                    //2 options here: either pairwise, equal sizes - OR every X TAD vs. entirety of Y
                    if (op.x().length() == op.y().length()) {
                        //Pairwise
                        if (!wholeDims && op.x().tensorsAlongDimension(dimension) != op.y().tensorsAlongDimension(dimension)) {
                            throw new ND4JIllegalStateException("Number of TADs along dimension doesn't match: (x shape = " +
                                    Arrays.toString(op.x().shape()) + ", y shape = " + Arrays.toString(op.y().shape()) +
                                    ", dimension = " + Arrays.toString(dimension) + ")");
                        }
                    } else {
                        if (dimension.length == 0)
                            throw new ND4JIllegalStateException("TAD vs TAD comparison requires a dimension (or was another comparison mode intended?)");

                        //Every X TAD vs. entirety of Y
                        val xTADSize = op.x().length() / op.x().tensorsAlongDimension(dimension);

                        if (xTADSize != op.y().length()) {
                            throw new ND4JIllegalStateException("Size of TADs along dimension doesn't match for pairwise execution:" +
                                    " (x TAD size = " + xTADSize + ", y size = " + op.y().length() + ")");
                        }
                    }
                }

                // for regular accumulations the previous contents of the result array don't matter
                ret = Nd4j.create(dtype, retShape);
            }
            op.setZ(ret);
        } else {
            // compare length

            if (op.z().length() != (retShape.length == 0 ? 1 : ArrayUtil.prodLong(retShape)))
                throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(op.z().shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]");
        }

        long st = profilingConfigurableHookIn(op);
        naiveExec(op, dimension);

        profilingConfigurableHookOut(op, null, st);

        return op.z();
    }

    @Override
    public INDArray exec(IndexAccumulation op) {
        val dimension = Shape.normalizeAxis(op.x().rank(), op.dimensions().toLongVector());

        if (op.x().isEmpty()) {
            for (val d:dimension) {
                Preconditions.checkArgument(op.x().size(d) != 0, "IndexReduce can't be issued along axis with 0 in shape");
            }
        }

        if (op.z() == null) {
            val retShape = Shape.reductionShape(op.x(), dimension, true, op.isKeepDims());
            op.setZ(Nd4j.createUninitialized(DataType.LONG, retShape));
        }

        long st = profilingConfigurableHookIn(op);

        checkForCompression(op);


        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));

        if (op.x().isVector() && op.x().length() == op.z().length()) {
            return op.x();
        }

        if (op.z().isEmpty())
            return op.z();

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            lastOp.set(op.opName());

        val context = AtomicAllocator.getInstance().getDeviceContext();

        val hostXShapeInfo =
                op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        val hostYShapeInfo =
                op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        val hostZShapeInfo =
                op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());

        val xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);

        Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);

        val hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
        val devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);

        val offsets = tadBuffers.getSecond();
        val devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);

        PointerPointer xShapeInfoHostPointer = extraz.get().put(
                AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(),
                context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(),
                hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets);
        Pointer extraArgs = op.extraArgs() != null
                ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.x().dataType()), context) : null;

        val x = op.x() == null ? null : op.x().data().opaqueBuffer();
        val y = op.y() == null ? null : op.y().data().opaqueBuffer();
        val z = op.z() == null ? null : op.z().data().opaqueBuffer();

        nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(),
                x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                extraArgs,
                z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                ((BaseCudaDataBuffer) op.dimensions().castTo(DataType.LONG).data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        profilingConfigurableHookOut(op, null, st);

        return op.z();
    }


    @Override
    public INDArray exec(Op op) {
        return exec(op, null);
    }

    @Override
    public INDArray exec(Op op, OpContext oc) {
        checkForCompression(op);

        if (op instanceof TransformOp) {
            TransformOp t = (TransformOp) op;
            invoke(t, oc);
        } else if (op instanceof ReduceOp) {
            ReduceOp acc = (ReduceOp) op;
            invoke(acc, oc, acc.dimensionsArr());
        } else if (op instanceof ScalarOp) {
            ScalarOp sc = (ScalarOp) op;
            invoke(sc, oc);
        } else if (op instanceof BroadcastOp) {
            BroadcastOp broadcastOp = (BroadcastOp) op;
            invoke(broadcastOp, oc);
        } else if (op instanceof IndexAccumulation) {
            IndexAccumulation indexAccumulation = (IndexAccumulation) op;
            invoke(indexAccumulation, oc, indexAccumulation.dimensions().toLongVector());
        } else if (op instanceof RandomOp) {
            exec((RandomOp) op, oc, Nd4j.getRandom());
        } else if (op instanceof CustomOp) {
            exec((CustomOp) op, oc);
        }


        return op.z();
    }


    @Override
    public TransformOp execAndReturn(TransformOp op) {
        checkForCompression(op);
        invoke(op, null);
        return op;
    }



    protected CudaContext invoke(BroadcastOp op, OpContext oc) {
        long st = profilingConfigurableHookIn(op);

        INDArray x = getX(op, oc);
        INDArray y = getY(op, oc);
        INDArray z = getZ(op, oc);

        checkForCompression(op);


        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));

        val context = AtomicAllocator.getInstance().getDeviceContext();

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            lastOp.set(op.opName());

        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context);


        val hostXShapeInfo =
                x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        val hostYShapeInfo =
                y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        val hostZShapeInfo =
                z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());

        val tadBuffers = tadManager.getTADOnlyShapeInfo(x, op.getDimension());

        val hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
        val devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);

        val offsets = tadBuffers.getSecond();
        val devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);

        Pointer devTadShapeInfoZ = null;
        Pointer devTadOffsetsZ = null;

        // second TAD set, for the output (z) array
        val tadBuffersZ = tadManager.getTADOnlyShapeInfo(z, op.getDimension());

        devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getFirst(), context);
        devTadOffsetsZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getSecond(), context);

        PointerPointer xShapeInfoHostPointer = extraz.get().put(
                AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), // 0
                context.getOldStream(), // 1
                AtomicAllocator.getInstance().getDeviceIdPointer(), // 2
                context.getBufferAllocation(), // 3
                context.getBufferReduction(),  // 4
                context.getBufferScalar(),  // 5
                context.getBufferSpecial(), // 6
                hostYShapeInfo,  // 7
                hostZShapeInfo,  // 8
                hostTadShapeInfo,  // 9
                devTadShapeInfo,  // 10
                devTadOffsets, // 11
                devTadShapeInfoZ,  // 12
                devTadOffsetsZ); // 13

        Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context);

        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context);
        Pointer dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(op.getDimension()), context);

        val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
        val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer();
        val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();


        switch (op.getOpType()) {
            case BROADCAST:
                nativeOps.execBroadcast(xShapeInfoHostPointer, op.opNum(),
                        xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
                        zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                        op.dimensions().castTo(DataType.LONG).data().opaqueBuffer(),
                        (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                break;
            case BROADCAST_BOOL:
                nativeOps.execBroadcastBool(xShapeInfoHostPointer, op.opNum(),
                        xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
                        zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                        null,
                        op.dimensions().castTo(DataType.LONG).data().opaqueBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                break;
            default:
                throw new UnsupportedOperationException("Unknown opType: " + op.getOpType());
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        profilingConfigurableHookOut(op, oc, st);

        return null;
    }



    protected CudaContext invoke(IndexAccumulation op, OpContext oc, long[] dimension) {
        INDArray x = getX(op, oc);
        INDArray y = getY(op, oc);
        INDArray z = getZ(op, oc);

        dimension = Shape.normalizeAxis(x.rank(), dimension);
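        // a null dimension array, or the single marker value Integer.MAX_VALUE, means "reduce over all dimensions"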
        if (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE)) {
            if(z == x || z == null) {
                z = Nd4j.createUninitialized(DataType.LONG, new long[0], 'c');
                setZ(z, op, oc);
            }
        }

        boolean keepDims = op.isKeepDims();
        long[] retShape = Shape.reductionShape(x, dimension, true, keepDims);

        if(z == null || x == z) {
            val ret = Nd4j.createUninitialized(DataType.LONG, retShape);

            setZ(ret, op, oc);
            z = ret;
        } else if(!Arrays.equals(retShape, z.shape())){
            throw new IllegalStateException("Z array shape does not match expected return type for op " + op
                    + ": expected shape " + Arrays.toString(retShape) + ", z.shape()=" + Arrays.toString(z.shape()));
        }

        long st = profilingConfigurableHookIn(op);

        checkForCompression(op);


        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            lastOp.set(op.opName());
        if (dimension != null)
            for (int i = 0; i < dimension.length; i++)
                if (dimension[i] >= x.rank() && dimension[i] != Integer.MAX_VALUE)
                    throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension) + " contains an element higher than the rank of op.X: [" + x.rank() + "]");

        val context = AtomicAllocator.getInstance().getDeviceContext();

        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context);
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(x.dataType()), context) : null;

        val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        val hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());

        long[] fdimension = dimension;
        if (fdimension == null)
            fdimension = new long[] {0};

        Pair<DataBuffer, DataBuffer> tadBuffers = tadManager.getTADOnlyShapeInfo(x, fdimension);

        Pointer hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
        Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);

        DataBuffer offsets = tadBuffers.getSecond();
        Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context);
        val zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context);

        val xb = op.x() == null ? null : op.x().data().opaqueBuffer();
        val zb = op.z() == null ? null : op.z().data().opaqueBuffer();

        PointerPointer xShapeInfoHostPointer = extraz.get().put(
                AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(),
                context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(),
                hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets);

        if (z.isScalar() || dimension == null || dimension[0] == Integer.MAX_VALUE) {
            nativeOps.execIndexReduceScalar(xShapeInfoHostPointer, op.opNum(),
                    xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                    extraArgs,
                    zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
        } else {
            if (dimension != null && dimension.length > 1)
                Arrays.sort(dimension);


            nativeOps.execIndexReduce(xShapeInfoHostPointer, op.opNum(),
                    xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                    extraArgs,
                    zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                    ((BaseCudaDataBuffer) op.dimensions().castTo(DataType.LONG).data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        profilingConfigurableHookOut(op, oc, st);

        return null;
    }


    protected CudaContext invoke(ReduceOp op, OpContext oc, long[] dimension) {
        val context = AtomicAllocator.getInstance().getDeviceContext();

        INDArray x = getX(op, oc);
        INDArray y = getY(op, oc);
        INDArray z = getZ(op, oc);

        if(op instanceof BaseReduceOp && ((BaseReduceOp)op).isEmptyReduce()) {
            //Edge case for TF import compatibility: [x,y].reduce(empty) = [x,y]
            //Note that "empty" axis is NOT the same as length 0, as in INDArray.sum(new int[0]), which means "all dimensions"
            if(z != null) {
                if(!x.isScalar() && !z.isScalar())
                    Preconditions.checkState(x.equalShapes(z), "For empty reductions, result (z) array must have same shape as x shape." +
                            " Got: x=%ndShape, z=%ndShape", x, z);
                z.assign(x);
                return context;
            } else {
                setZ(x.dup(), op, oc);
                return context;
            }
        }

        // FIXME: this should be moved down to C++ on a per-op basis
        // reduce-to-scalar case: ReduceBool ops on empty inputs require special treatment
        if (op instanceof BaseReduceBoolOp && x.isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) {
            if (z == null) {
                op.setZ(Nd4j.scalar(((BaseReduceBoolOp) op).emptyValue()));
            } else {
                z.assign(((BaseReduceBoolOp) op).emptyValue());
            }

            return context;
        }

        long st = profilingConfigurableHookIn(op);

        checkForCompression(op);

        dimension = Shape.normalizeAxis(x.rank(), dimension);


        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));

        // null dimensions mean a full-array reduction
        if (dimension == null)
            dimension = new long[] {Integer.MAX_VALUE};

        if (dimension.length > 1)
            Arrays.sort(dimension);

        for (int i = 0; i < dimension.length; i++)
            if (dimension[i] >= x.rank() && dimension[i] != Integer.MAX_VALUE)
                throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension)
                        + " contains an element higher than the rank of op.X: [" + x.rank() + "]");

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            lastOp.set(op.opName());

        val tadBuffers = x.isEmpty() ? Pair.<DataBuffer, DataBuffer>makePair(x.data(), null) : tadManager.getTADOnlyShapeInfo(x, dimension);

        val hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
        val devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);

        val offsets = x.isEmpty() ? null : tadBuffers.getSecond();
        val devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer((DataBuffer) offsets, context);

        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context);

        long[] retShape = Shape.reductionShape(x, dimension, true, op.isKeepDims());

        if (y != null) {
            //2 options here: either pairwise, equal sizes - OR every X TAD vs. entirety of Y
            if (x.length() == y.length()) {
                //Pairwise
                if (x.tensorsAlongDimension(dimension) != y.tensorsAlongDimension(dimension)) {
                    throw new ND4JIllegalStateException("Number of TADs along dimension doesn't match: (x shape = " +
                            Arrays.toString(x.shape()) + ", y shape = " + Arrays.toString(y.shape()) +
                            ", dimension = " + Arrays.toString(dimension) + ")");
                }
            } else if(!(op instanceof ReduceOp)) {
                //Every X TAD vs. entirety of Y
                val xTADSize = x.length() / x.tensorsAlongDimension(dimension);

                if (xTADSize != y.length()) {
                    throw new ND4JIllegalStateException("Size of TADs along dimension doesn't match for pairwise execution:" +
                            " (x TAD size = " + xTADSize + ", y size = " + y.length() + ")");
                }
            }
        }



        val dataType = oc != null ? op.resultType(oc) : op.resultType();

        if( z == null ){
            val ret = Nd4j.createUninitialized(dataType, retShape);
            setZ(ret, op, oc);
            z = ret;
        } else if(z.dataType() != dataType || !Arrays.equals(retShape, z.shape())){
            throw new ND4JIllegalStateException("Output array for op " + op.getClass().getSimpleName() + " should have type " + dataType + " and shape " + Arrays.toString(retShape)
                    + " but has datatype " + z.dataType() + " and shape " + Arrays.toString(z.shape()));
        }

        val eb = op.extraArgsDataBuff(z.dataType() == DataType.BOOL || op.getOpType() == Op.Type.REDUCE_LONG ? x.dataType() : z.dataType());
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(eb, context) : null;

        val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        val hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());

        val xShapeInfoHostPointer = extraz.get().put(
                AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(),
                context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(),
                hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets);

        val yTadBuffers = y == null ? null : tadManager.getTADOnlyShapeInfo(y, dimension);

        val yDevTadShapeInfo = y == null ? null : AtomicAllocator.getInstance().getPointer(yTadBuffers.getFirst(), context);
        val yOffsets = y == null ? null : yTadBuffers.getSecond();
        val yDevTadOffsets = yOffsets == null ? null : AtomicAllocator.getInstance().getPointer(yOffsets, context);

        if (y != null) {
            xShapeInfoHostPointer.put(12L, yDevTadShapeInfo);
            xShapeInfoHostPointer.put(13L, yDevTadOffsets);
        }

        val zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context);

        val xb = x == null ? null : x.data().opaqueBuffer();
        val yb = y == null ? null : y.data().opaqueBuffer();
        val zb = z == null ? null : z.data().opaqueBuffer();

        op.validateDataTypes(null);

        if (z.isScalar()) {
            if (op instanceof Variance) {
                nativeOps.execSummaryStatsScalar(xShapeInfoHostPointer, op.opNum(),
                        xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        extraArgs,
                        zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                        ((Variance) op).isBiasCorrected());
            } else if (y != null) {
                Pointer yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context);
                nativeOps.execReduce3Scalar(xShapeInfoHostPointer, op.opNum(),
                        xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        extraArgs,
                        yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
                        zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
            } else {
                switch (op.getOpType()) {
                    case REDUCE_FLOAT:
                        nativeOps.execReduceFloat(xShapeInfoHostPointer, op.opNum(),
                                xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
                                zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
                        break;
                    case REDUCE_BOOL:
                        nativeOps.execReduceBool(xShapeInfoHostPointer, op.opNum(),
                                xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
                                zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
                        break;
                    case REDUCE_SAME:
                        nativeOps.execReduceSame(xShapeInfoHostPointer, op.opNum(),
                                xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
                                zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
                        break;
                    case REDUCE_LONG:
                        nativeOps.execReduceLong(xShapeInfoHostPointer, op.opNum(),
                                xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                extraArgs,
                                zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo);
                        break;
                    default:
                        throw new UnsupportedOperationException();
                }
            }
        } else {
            val dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);

            if (y != null) {
                val yShapeInfo = AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context);
                nativeOps.execReduce3Tad(xShapeInfoHostPointer, op.opNum(),
                        xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        extraArgs,
                        yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
                        zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                        op.dimensions().castTo(DataType.LONG).data().opaqueBuffer(),
                        (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null,
                        (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets, (LongPointer) yDevTadShapeInfo,
                        (LongPointer) yDevTadOffsets);
            } else {
                if (op instanceof Variance) {
                    nativeOps.execSummaryStatsTad(xShapeInfoHostPointer, op.opNum(),
                            xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                            extraArgs,
                            zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            op.dimensions().castTo(DataType.LONG).data().opaqueBuffer(),
                            (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null,
                            ((Variance) op).isBiasCorrected(),
                            (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets);
                } else {
                    switch (op.getOpType()) {
                        case REDUCE_FLOAT:
                            nativeOps.execReduceFloat2(xShapeInfoHostPointer, op.opNum(),
                                    xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                    extraArgs,
                                    zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                                    op.dimensions().castTo(DataType.LONG).data().opaqueBuffer(),
                                    (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                            break;
                        case REDUCE_SAME:
                            nativeOps.execReduceSame2(xShapeInfoHostPointer, op.opNum(),
                                    xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                    extraArgs,
                                    zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                                    ((BaseCudaDataBuffer) op.dimensions().castTo(DataType.LONG).data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                            break;
                        case REDUCE_BOOL:
                            nativeOps.execReduceBool2(xShapeInfoHostPointer, op.opNum(),
                                    xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                    extraArgs,
                                    zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                                    ((BaseCudaDataBuffer) op.dimensions().castTo(DataType.LONG).data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                            break;
                        case REDUCE_LONG:
                            nativeOps.execReduceLong2(xShapeInfoHostPointer, op.opNum(),
                                    xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                                    extraArgs,
                                    zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                                    ((BaseCudaDataBuffer) op.dimensions().castTo(DataType.LONG).data()).getOpaqueDataBuffer(), (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null);
                            break;
                        default:
                            throw new UnsupportedOperationException();
                    }
                }
            }
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        profilingConfigurableHookOut(op, oc, st);

        Nd4j.getExecutioner().commit();

        return context;
    }


    protected CudaContext intercept(ScalarOp op, long[] dimension) {
        long st = profilingConfigurableHookIn(op);
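
        // dimensions must be handed to the native TAD kernels in ascending order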

        if (dimension != null && dimension.length > 1)
            Arrays.sort(dimension);

        val context = AtomicAllocator.getInstance().getDeviceContext();

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            lastOp.set(op.opName());

        val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer());
        val hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer());
        val hostZShapeInfo = op.z() == null ? null : AddressRetriever.retrieveHostPointer(op.z().shapeInfoDataBuffer());

        val xShapeInfo = AtomicAllocator.getInstance().getPointer(op.x().shapeInfoDataBuffer(), context);
        val yShapeInfo = AtomicAllocator.getInstance().getPointer(op.y().shapeInfoDataBuffer(), context);
        val zShapeInfo = AtomicAllocator.getInstance().getPointer(op.z().shapeInfoDataBuffer(), context);

        val tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), dimension);

        val hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst());
        val devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context);

        val offsets = tadBuffers.getSecond();
        val devTadOffsets = AtomicAllocator.getInstance().getPointer(offsets, context);

        Pointer devTadShapeInfoZ = null;
        Pointer devTadOffsetsZ = null;

        val tadBuffersZ = tadManager.getTADOnlyShapeInfo(op.z(), dimension);

        devTadShapeInfoZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getFirst(), context);
        devTadOffsetsZ = AtomicAllocator.getInstance().getPointer(tadBuffersZ.getSecond(), context);


        PointerPointer extraPointers = extraz.get().put(
                AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()), context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(),
                context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(),
                hostYShapeInfo, hostZShapeInfo, hostTadShapeInfo, devTadShapeInfo, devTadOffsets,
                devTadShapeInfoZ, devTadOffsetsZ);

        val extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.z().dataType()), context) : null;

        val dimensionPointer = AtomicAllocator.getInstance().getPointer(AtomicAllocator.getInstance().getConstantBuffer(dimension), context);

        val x = op.x() == null ? null : op.x().data().opaqueBuffer();
        val y = op.y() == null ? null : op.y().data().opaqueBuffer();
        val z = op.z() == null ? null : op.z().data().opaqueBuffer();

        switch (op.getOpType()) {
            case SCALAR:
                nativeOps.execScalarTad(extraPointers, op.opNum(),
                        x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                        y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
                        extraArgs,
                        op.dimensions().castTo(DataType.LONG).data().opaqueBuffer(),
                        (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null,
                        (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets,
                        (LongPointer) devTadShapeInfoZ, (LongPointer) devTadOffsetsZ);
                break;
            case SCALAR_BOOL:
                nativeOps.execScalarBoolTad(extraPointers, op.opNum(),
                        x, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        z, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                        y, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
                        extraArgs,
                        op.dimensions().castTo(DataType.LONG).data().opaqueBuffer(),
                        (LongPointer) op.dimensions().shapeInfoDataBuffer().addressPointer(), null,
                        (LongPointer) devTadShapeInfo, (LongPointer) devTadOffsets,
                        (LongPointer) devTadShapeInfoZ, (LongPointer) devTadOffsetsZ);
                break;
            default:
                throw new UnsupportedOperationException();
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        profilingConfigurableHookOut(op, null, st);

        return null;
    }

    @Override
    public INDArray exec(ScalarOp op) {
        invoke(op, null);
        return op.z();
    }

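    /**
     * Invokes a ScalarOp. When dimensions are set on the op, execution is delegated to
     * {@link #intercept(ScalarOp, long[])}; otherwise execScalar/execScalarBool is called directly.
     */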
    protected CudaContext invoke(ScalarOp op, OpContext oc) {
        long st = profilingConfigurableHookIn(op);

        checkForCompression(op);

        INDArray x = getX(op, oc);
        INDArray y = getY(op, oc);
        INDArray z = getZ(op, oc);


        if(z == null){
            switch (op.getOpType()) {
                case SCALAR:
                    z = x.ulike();
                    setZ(z, op, oc);
                    break;
                case SCALAR_BOOL:
                    z = Nd4j.createUninitialized(DataType.BOOL, x.shape());
                    setZ(z, op, oc);
                    break;
                default:
                    throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]");
            }
        }

        if (x.length() != z.length())
            throw new ND4JIllegalStateException("op.X length should be equal to op.Y length: ["
                    + Arrays.toString(x.shapeInfoDataBuffer().asInt()) + "] != ["
                    + Arrays.toString(z.shapeInfoDataBuffer().asInt()) + "]");

        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            lastOp.set(op.opName());

        if (op.dimensions() != null) {
            intercept(op, op.dimensions().toLongVector());
            return null;
        }

        val context = AtomicAllocator.getInstance().getDeviceContext();

        val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        val hostYShapeInfo = op.scalar() == null ? null : AddressRetriever.retrieveHostPointer(op.scalar().shapeInfoDataBuffer());
        val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());

        Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context);
        Pointer extraArgs = op.extraArgs() != null ? AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.SCALAR_BOOL ? x.dataType() : z.dataType()), context) : null;

        Pointer zShapeInfo = AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context);

        PointerPointer xShapeInfoHostPointer = extraz.get().put(
                AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), context.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(), context.getBufferAllocation(),
                context.getBufferReduction(), context.getBufferScalar(), context.getBufferSpecial(),
                hostYShapeInfo, hostZShapeInfo, null, null);

        val xb = x == null ? null : x.data().opaqueBuffer();
        val yb = op.scalar() == null ? null : op.scalar().data().opaqueBuffer();
        val zb = z == null ? null : z.data().opaqueBuffer();

        switch (op.getOpType()) {
            case SCALAR_BOOL:
                nativeOps.execScalarBool(xShapeInfoHostPointer, op.opNum(),
                        xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                        yb, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context),
                        extraArgs);
                break;
            case SCALAR:
                nativeOps.execScalar(xShapeInfoHostPointer, op.opNum(),
                        xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                        zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                        yb, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(op.scalar().shapeInfoDataBuffer(), context),
                        extraArgs);
                break;
            default:
                throw new UnsupportedOperationException("Unknown op type: " + op.getOpType());
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        profilingConfigurableHookOut(op, oc, st);

        return null;
    }

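    /**
     * Invokes a TransformOp: pairwise variants are used when y is present, otherwise the
     * plain transform variants (ANY/FLOAT/BOOL/SAME/STRICT) are dispatched on the op type.
     */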
    protected CudaContext invoke(TransformOp op, OpContext oc) {
        long st = profilingConfigurableHookIn(op);

        INDArray x = getX(op, oc);
        INDArray y = getY(op, oc);
        INDArray z = getZ(op, oc);

        checkForCompression(op);

        //validateDataType(Nd4j.dataType(), op);

        AtomicAllocator allocator = AtomicAllocator.getInstance();

        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));

        val context = allocator.getDeviceContext();

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            lastOp.set(op.opName());

        // special temp array for IsMax along dimension
        INDArray ret = null;

        Pointer xShapeInfo = allocator.getPointer(x.shapeInfoDataBuffer(), context);


        Pointer dimensionDevPointer = null;
        Pointer dimensionHostPointer = null;
        Pointer retPointer = null;
        Pointer retHostShape = null;
        int[] dimension = null;

        val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        var hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());


        if (z == null) {
            ret = Nd4j.createUninitialized(op.resultType(), x.shape(), x.ordering());
            setZ(ret, op, oc);
            z = ret;
        }

        var extraArgs = op.extraArgs() != null ? allocator.getPointer(op.extraArgsDataBuff(op.getOpType() == Op.Type.TRANSFORM_BOOL || op.getOpType() == Op.Type.PAIRWISE_BOOL ? x.dataType() : z.dataType()), context) : null;
        val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());

        Pointer hostTadShapeInfo = null;
        Pointer devTadShapeInfo = null;

        Pointer hostMaxTadShapeInfo = null;
        Pointer devMaxTadShapeInfo = null;

        Pair<DataBuffer, DataBuffer> tadBuffers;
        Pair<DataBuffer, DataBuffer> tadMaxBuffers;

        Pointer devTadOffsets = null;
        Pointer devMaxTadOffsets = null;

        op.validateDataTypes(oc, experimentalMode.get());

        Pointer zShapeInfo = allocator.getPointer(z.shapeInfoDataBuffer(), context);


        PointerPointer xShapeInfoHostPointer =
                extraz.get().put(AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()), // 0
                        (Pointer) context.getOldStream(), // 1
                        allocator.getDeviceIdPointer(), // 2
                        context.getBufferAllocation(), // 3
                        context.getBufferReduction(), // 4
                        context.getBufferScalar(), // 5
                        context.getBufferSpecial(), // 6
                        (Pointer) hostYShapeInfo, // 7
                        (Pointer) hostZShapeInfo, // 8
                        hostTadShapeInfo, // 9
                        devTadShapeInfo, // 10
                        devTadOffsets, // 11
                        hostMaxTadShapeInfo, // 12
                        devMaxTadShapeInfo, // 13
                        devMaxTadOffsets, // 14
                        dimensionDevPointer, // special pointer for IsMax  // 15
                        dimensionHostPointer, // special pointer for IsMax  // 16
                        retPointer, // special pointer for IsMax // 17
                        (Pointer) new CudaPointer(dimension == null ? 0 : dimension.length),
                        retHostShape);


        val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
        val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer();
        val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();

        if (y != null) {
            Pointer yShapeInfo = allocator.getPointer(y.shapeInfoDataBuffer(), context);

            switch (op.getOpType()) {
                case TRANSFORM_BOOL:
                case PAIRWISE_BOOL:
                    nativeOps.execPairwiseTransformBool(xShapeInfoHostPointer, op.opNum(),
                            xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                            yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
                            zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
                default:
                    nativeOps.execPairwiseTransform(xShapeInfoHostPointer, op.opNum(),
                            xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                            yb, (LongPointer) hostYShapeInfo, (LongPointer) yShapeInfo,
                            zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
            }
        } else {
            switch (op.getOpType()) {
                case TRANSFORM_ANY:
                    nativeOps.execTransformAny(xShapeInfoHostPointer, op.opNum(),
                            xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                            zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
                case TRANSFORM_FLOAT:
                    nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(),
                            xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                            zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
                case TRANSFORM_BOOL:
                    nativeOps.execTransformBool(xShapeInfoHostPointer, op.opNum(),
                            xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                            zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
                case TRANSFORM_SAME:
                    nativeOps.execTransformSame(xShapeInfoHostPointer, op.opNum(),
                            xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                            zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
                case TRANSFORM_STRICT:
                    nativeOps.execTransformStrict(xShapeInfoHostPointer, op.opNum(),
                            xb, (LongPointer) hostXShapeInfo, (LongPointer) xShapeInfo,
                            zb, (LongPointer) hostZShapeInfo, (LongPointer) zShapeInfo,
                            extraArgs);
                    break;
                default:
                    throw new UnsupportedOperationException();
            }
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        // touch host-side objects so they remain strongly reachable until the async call is enqueued
        if (extraArgs != null)
            extraArgs.address();

        if (ret != null)
            ret.elementWiseStride();

        profilingConfigurableHookOut(op, oc, st);

        return null;
    }

    protected <T extends Aggregate> DataBuffer getBuffer(Batch<T> batch) {
        DataBuffer buffer = Nd4j.getDataBufferFactory().createInt(batch.getSample().getRequiredBatchMemorySize() * 4,
                false);
        batch.setParamsSurface(buffer);
        return buffer;
    }

    @Override
    public <T extends Aggregate> void exec(Batch<T> batch) {
        throw new UnsupportedOperationException("Pew-pew");
    }

    @Override
    public void exec(List<Aggregate> batch) {
        if (batch.size() == 0)
            return;

        List<Batch<Aggregate>> batches = Batch.getBatches(batch, 8192);
        for (Batch<Aggregate> single : batches) {
            this.exec(single);
        }

        val context = AtomicAllocator.getInstance().getDeviceContext();
        context.syncOldStream();
    }

    @Override
    public void exec(Aggregate op) {
        throw new UnsupportedOperationException("Pew-pew");
    }

    /**
     * This method executes the specified RandomOp using the default RNG available via Nd4j.getRandom().
     *
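     * <p>
     * A minimal usage sketch (assuming a random op such as
     * {@code org.nd4j.linalg.api.ops.random.impl.GaussianDistribution}):
     * <pre>{@code
     * INDArray z = Nd4j.createUninitialized(DataType.FLOAT, 2, 3);
     * Nd4j.getExecutioner().exec(new GaussianDistribution(z, 0.0, 1.0)); // fills z in place
     * }</pre>
     *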
     * @param op the random op to execute
     */
    @Override
    public INDArray exec(RandomOp op) {
        return exec(op, Nd4j.getRandom());
    }


    @Override
    public INDArray exec(RandomOp op, Random rng) {
        return exec(op, null, rng);
    }

    public INDArray exec(RandomOp op, OpContext oc, Random rng) {
        INDArray x = getX(op, oc);
        INDArray y = getY(op, oc);
        INDArray z = getZ(op, oc);

        if(op instanceof BaseRandomOp && ((BaseRandomOp)op).isTripleArgRngOp() && z != null && x == null && y == null){
            //Ugly hack to ensure the triple arg call occurs
            //See GaussianDistribution.setZ etc
            x = z;
            y = z;
        }

        long st = profilingConfigurableHookIn(op);

        checkForCompression(op);


        if (rng.getStatePointer() == null)
            throw new IllegalStateException(
                    "You should use one of NativeRandom classes for NativeOperations execution");

        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));

        if (CudaEnvironment.getInstance().getConfiguration().isDebug())
            lastOp.set(op.opName());

        val context = AtomicAllocator.getInstance().getDeviceContext();

        PointerPointer extraZZ = extraz.get().put(AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()),
                context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer());

        val hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer());
        val hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer());
        val hostZShapeInfo = z == null ? null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer());

        val xb = x == null ? null : ((BaseCudaDataBuffer) x.data()).getOpaqueDataBuffer();
        val yb = y == null ? null : ((BaseCudaDataBuffer) y.data()).getOpaqueDataBuffer();
        val zb = z == null ? null : ((BaseCudaDataBuffer) z.data()).getOpaqueDataBuffer();

        if (x != null && y != null && z != null) {
            // triple arg call
            nativeOps.execRandom3(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr
                    xb, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context),
                    yb, (LongPointer) hostYShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(y.shapeInfoDataBuffer(), context),
                    zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context),
                    AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()), context));

        } else if (x != null && z != null) {
            //double arg call
            nativeOps.execRandom2(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr
                    xb, (LongPointer) hostXShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context),
                    zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context),
                    AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()),context));


        } else {
            // single arg call
            nativeOps.execRandom(extraZZ, op.opNum(), rng.getStatePointer(), // rng state ptr
                    zb, (LongPointer) hostZShapeInfo, (LongPointer) AtomicAllocator.getInstance().getPointer(z.shapeInfoDataBuffer(), context),
                    AtomicAllocator.getInstance().getPointer(op.extraArgsDataBuff(z.dataType()), context));
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        profilingConfigurableHookOut(op, oc, st);

        return z;
    }

    /**
     * This method returns a set of key/value and key/key/value objects
     * describing the current environment (backend, devices, memory, bandwidth).
     *
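     * <p>
     * A minimal usage sketch, reading the per-device information this method fills in:
     * <pre>{@code
     * Properties env = Nd4j.getExecutioner().getEnvironmentInformation();
     * List<Map<String, Object>> devices =
     *         (List<Map<String, Object>>) env.get(Nd4jEnvironment.CUDA_DEVICE_INFORMATION_KEY);
     * for (Map<String, Object> dev : devices)
     *     System.out.println(dev.get(Nd4jEnvironment.CUDA_DEVICE_NAME_KEY));
     * }</pre>
     *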
     * @return properties describing the current environment
     */
    @Override
    public synchronized Properties getEnvironmentInformation() {
        if (properties == null) {
            Properties props = super.getEnvironmentInformation();

            List<Map<String, Object>> devicesList = new ArrayList<>();

            // fill with per-device information: name, memory, versions
            for (int i = 0; i < nativeOps.getAvailableDevices(); i++) {
                Map<String, Object> deviceProps = new HashMap<>();

                deviceProps.put(Nd4jEnvironment.CUDA_DEVICE_NAME_KEY, nativeOps.getDeviceName(i));
                deviceProps.put(Nd4jEnvironment.CUDA_FREE_MEMORY_KEY, nativeOps.getDeviceFreeMemory(i));
                deviceProps.put(Nd4jEnvironment.CUDA_TOTAL_MEMORY_KEY, nativeOps.getDeviceTotalMemory(i));
                deviceProps.put(Nd4jEnvironment.CUDA_DEVICE_MAJOR_VERSION_KEY, (long) nativeOps.getDeviceMajor(i));
                deviceProps.put(Nd4jEnvironment.CUDA_DEVICE_MINOR_VERSION_KEY, (long) nativeOps.getDeviceMinor(i));

                devicesList.add(i, deviceProps);
            }

            // fill with basic general info
            props.put(Nd4jEnvironment.BACKEND_KEY, "CUDA");
            props.put(Nd4jEnvironment.CUDA_NUM_GPUS_KEY, nativeOps.getAvailableDevices());
            props.put(Nd4jEnvironment.CUDA_DEVICE_INFORMATION_KEY, devicesList);
            props.put(Nd4jEnvironment.BLAS_VENDOR_KEY, (Nd4j.factory().blas()).getBlasVendor().toString());
            props.put(Nd4jEnvironment.HOST_FREE_MEMORY_KEY, Pointer.maxBytes() - Pointer.totalBytes());

            // fill bandwidth information
            props.put(Nd4jEnvironment.MEMORY_BANDWIDTH_KEY, PerformanceTracker.getInstance().getCurrentBandwidth());

            properties = props;
        } else {

            List<Map<String, Object>> devicesList = (List<Map<String, Object>>) properties.get(Nd4jEnvironment.CUDA_DEVICE_INFORMATION_KEY);

            // just update information that might change over time
            for (int i = 0; i < nativeOps.getAvailableDevices(); i++) {
                Map<String, Object> dev = devicesList.get(i);

                dev.put(Nd4jEnvironment.CUDA_FREE_MEMORY_KEY, nativeOps.getDeviceFreeMemory(i));
                dev.put(Nd4jEnvironment.CUDA_TOTAL_MEMORY_KEY, nativeOps.getDeviceTotalMemory(i));
            }

            properties.put(Nd4jEnvironment.CUDA_DEVICE_INFORMATION_KEY, devicesList);
            properties.put(Nd4jEnvironment.HOST_FREE_MEMORY_KEY, Pointer.maxBytes() - Pointer.totalBytes());

            // fill bandwidth information
            properties.put(Nd4jEnvironment.MEMORY_BANDWIDTH_KEY, PerformanceTracker.getInstance().getCurrentBandwidth());
        }
        return properties;
    }

    @Override
    public TADManager getTADManager() {
        return tadManager;
    }

    @Override
    @SuppressWarnings("unchecked")
    public void printEnvironmentInformation() {
        super.printEnvironmentInformation();
    }

    @Override
    public void commit() {
        val ctx = AtomicAllocator.getInstance().getDeviceContext();
        ctx.syncOldStream();
        ctx.syncSpecialStream();
    }

    @Override
    public synchronized Map<String, CustomOpDescriptor> getCustomOperations() {
        if(customOps == null) {
            String list = nativeOps.getAllCustomOps();

            if (list == null || list.isEmpty()) {
                log.warn("No customs ops available!");
                customOps = Collections.emptyMap();
                return customOps;
            }

            val map = new HashMap<String, CustomOpDescriptor>();

            String[] split = list.split(";");
            for (String op : split) {
                if (op == null || op.isEmpty())
                    continue;

                String[] another = op.split(":");

                CustomOpDescriptor descriptor = CustomOpDescriptor.builder()
                        .hash(Long.valueOf(another[1]))
                        .numInputs(Integer.valueOf(another[2]))
                        .numOutputs(Integer.valueOf(another[3]))
                        .allowsInplace(Integer.valueOf(another[4]) == 1)
                        .numTArgs(Integer.valueOf(another[5]))
                        .numIArgs(Integer.valueOf(another[6]))
                        .build();

                map.put(another[0], descriptor);
            }

            customOps = Collections.unmodifiableMap(map);
        }

        return customOps;
    }



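    /**
     * Reconstructs a {@link LongShapeDescriptor} from a raw shapeInfo pointer: the rank is
     * read from element 0, followed by shape, strides and the array-options word.
     */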
    protected LongShapeDescriptor getShapeFromPointer(LongPointer ptr) {
        val rank = (int) ptr.get(0);

        val shape = new long[rank * 2 + 4];
        for (int i = 0; i < shape.length; i++) {
            shape[i] = ptr.get(i);
        }

        //val extras = ptr.get(Shape.shapeInfoLength(rank) - 3);
        val t = ArrayOptionsHelper.arrayType(shape);
        return LongShapeDescriptor.fromShape(Shape.shape(shape), Shape.stride(shape), Shape.elementWiseStride(shape), Shape.order(shape), ArrayOptionsHelper.dataType(shape), t == ArrayType.EMPTY);
    }

    @Override
    public List<LongShapeDescriptor> calculateOutputShape(@NonNull CustomOp op) {
        return calculateOutputShape(op, null);
    }

    @Override
    public List<LongShapeDescriptor> calculateOutputShape(@NonNull CustomOp op, OpContext opContext) {

        Nd4j.getExecutioner().commit();

        val lc = op.opName().toLowerCase();
        val hash = op.opHash();

        val result = new ArrayList<LongShapeDescriptor>();
        int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments();
        if(nIn == 0 && op.getDescriptor().getNumInputs() >= 1) {
            if(log.isTraceEnabled()){
                log.trace("Could not calculate output shape for op {}: number of input args was 0",
                        op.getClass().getName());
            }
            return Collections.emptyList();
        }

        val inputBuffers = new PointerPointer<>(nIn * 2);
        val inputShapes = new PointerPointer<>(nIn);

        val inputArgs = opContext != null ? opContext.getInputArrays() : op.inputArguments();
        int cnt = 0;
        for (val in: inputArgs) {
            // TODO: once we implement Context-based shape function call this method should be removed
            val loc = Nd4j.getAffinityManager().getActiveLocation(in);
            if (loc != AffinityManager.Location.DEVICE && loc != AffinityManager.Location.EVERYWHERE) {
                Nd4j.getAffinityManager().ensureLocation(in, AffinityManager.Location.DEVICE);
            }

            // NOT A TYPO: shape functions work on host side only
            if (!in.isEmpty()) {
                inputBuffers.put(cnt, in.data().addressPointer());
                inputBuffers.put(cnt + nIn, AtomicAllocator.getInstance().getPointer(in.data()));
            }

            inputShapes.put(cnt++, in.shapeInfoDataBuffer().addressPointer());
        }


        int nIArgs = opContext != null ? opContext.numIArguments() : op.numIArguments();
        val iArgs = nIArgs > 0 ? new LongPointer(nIArgs) : null;
        cnt = 0;
        if(opContext != null) {
            for (val i: opContext.getIArguments())
                iArgs.put(cnt++, i);
        } else {
            for (val i: op.iArgs())
                iArgs.put(cnt++, i);
        }


        int nTArgs = opContext != null ? opContext.numTArguments() : op.numTArguments();
        val tArgs = nTArgs > 0 ? new DoublePointer(nTArgs) : null;

        int nBArgs = opContext != null ? opContext.numBArguments() : op.numBArguments();
        val bArgs = nBArgs > 0 ? new BooleanPointer(nBArgs) : null;

        int nDArgs = opContext != null ? opContext.numDArguments() : op.numDArguments();
        val dArgs = nDArgs > 0 ? new IntPointer(nDArgs) : null;

        cnt = 0;
        if(opContext != null){
            for (val b: opContext.getBArguments())
                bArgs.put(cnt++, b);
        } else {
            for (val b: op.bArgs())
                bArgs.put(cnt++, b);
        }


        cnt = 0;
        if(opContext != null){
            for (val b: opContext.getTArguments())
                tArgs.put(cnt++, b);
        } else {
            for (val b: op.tArgs())
                tArgs.put(cnt++, b);
        }

        cnt = 0;
        if(opContext != null) {
            for (val b: opContext.getDArguments())
                dArgs.put(cnt++, b.toInt());
        } else {
            for (val b: op.dArgs())
                dArgs.put(cnt++, b.toInt());
        }

        OpaqueShapeList ptrptr = nativeOps.calculateOutputShapes2(null,
                hash, inputBuffers, inputShapes, nIn, tArgs, nTArgs,
                iArgs, nIArgs, bArgs, nBArgs, dArgs, nDArgs);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        if (ptrptr == null)
            throw new RuntimeException();

        for (int e = 0; e < nativeOps.getShapeListSize(ptrptr); e++ )
            result.add(getShapeFromPointer(new PagedPointer(nativeOps.getShape(ptrptr, e)).asLongPointer()));

        nativeOps.deleteShapeList(ptrptr);


        return result;
    }

    /**
     * This method executes the given CustomOp.
     *
     * PLEASE NOTE: you are responsible for input/output validation.
     * PLEASE NOTE: right now these operations are executed on the CPU.
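     * <p>
     * A minimal usage sketch (assuming {@code x} and {@code y} are INDArrays and the op name
     * is one of those listed by {@link #getCustomOperations()}):
     * <pre>{@code
     * DynamicCustomOp op = DynamicCustomOp.builder("add")
     *         .addInputs(x, y)
     *         .build();
     * INDArray[] out = Nd4j.getExecutioner().exec(op); // outputs are allocated via calculateOutputShape
     * }</pre>
     *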
     * @param op the custom op to execute
     */
    @Override
    public INDArray[] exec(CustomOp op) {

        Nd4j.getExecutioner().commit();

        boolean shapeOverride = false;
        if (op.numOutputArguments() == 0 && !op.isInplaceCall()) {
            try {
                val list = this.calculateOutputShape(op);
                if (list.isEmpty())
                    throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");

                for (val shape: list)
                    op.addOutputArgument(Nd4j.create(shape, false));

                shapeOverride = true;
            } catch (Exception e) {
                throw new ND4JIllegalStateException("Op name " + op.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", e);
            }
        }



        val name = op.opName();
        try (val context = (CudaOpContext) buildContext()) {
            // optionally skip shape validation on op execution
            if (shapeOverride)
                context.shapeFunctionOverride(true);

            context.markInplace(op.isInplaceCall());

            // transferring rng state
            context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState());

            //transferring input/output arrays
            context.setInputArrays(op.inputArguments());
            context.setOutputArrays(op.outputArguments());

            // transferring static args
            context.setBArguments(op.bArgs());
            context.setIArguments(op.iArgs());
            context.setTArguments(op.tArgs());
            context.setDArguments(op.dArgs());

            val result = exec(op, context);
            val states = context.getRngStates();


            // pulling states back
            Nd4j.getRandom().setStates(states.getFirst(), states.getSecond());

            return result;
        } catch (ND4JOpProfilerException e) {
            throw e;
        } catch (Exception e) {
            // cudaGetErrorName expects a CUDA runtime error code, not a cuBLAS status
            String message = "Op [" + name + "] execution failed. CUDA last error: "
                    + cudaGetErrorName(cudaGetLastError()).getString();
            throw new RuntimeException(message, e);
        }
    }

    @Override
    public void enableDebugMode(boolean reallyEnable) {
        debug.set(reallyEnable);
        nativeOps.enableDebugMode(reallyEnable);
    }

    @Override
    public void enableVerboseMode(boolean reallyEnable) {
        verbose.set(reallyEnable);
        nativeOps.enableVerboseMode(reallyEnable);
    }

    @Override
    public void registerGraph(long id, Pointer graph) {
        nativeOps.registerGraph(null, id, graph);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());
    }

    @Override
    public Map<String, INDArray> executeGraph(long id, @NonNull Map<String, INDArray> map, @NonNull Map<String, Integer> reverseMap) {

        Nd4j.getExecutioner().commit();

        val ptrBuffers = new PointerPointer(map.size() * 2);
        val ptrShapes = new PointerPointer(map.size() * 2);
        val ptrIndices = new IntPointer(map.size());

        int cnt = 0;
        val keySet = new ArrayList<>(map.keySet());
        for (val key: keySet) {
            val array = map.get(key);

            ptrBuffers.put(cnt, AtomicAllocator.getInstance().getHostPointer(array));
            ptrShapes.put(cnt, AtomicAllocator.getInstance().getHostPointer(array.shapeInfoDataBuffer()));
            ptrIndices.put(cnt, reverseMap.get(key));

            cnt++;
        }

        val newMap = new LinkedHashMap<String, INDArray>();

        OpaqueVariablesSet result = nativeOps.executeStoredGraph(null, id, ptrBuffers, ptrShapes, ptrIndices, map.size());

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        OpStatus status = OpStatus.byNumber(nativeOps.getVariablesSetStatus(result));

        if (status != OpStatus.ND4J_STATUS_OK)
            throw new ND4JIllegalStateException("Op execution failed: " + status);

        for (int e = 0; e < nativeOps.getVariablesSetSize(result); e++) {
            OpaqueVariable var = nativeOps.getVariable(result, e);
            int nodeId = nativeOps.getVariableId(var);
            int index = nativeOps.getVariableIndex(var);
            LongPointer shapeInfo = nativeOps.getVariableShape(var);
            Pointer buffer = nativeOps.getVariableBuffer(var);

            val rank = (int) shapeInfo.get(0);
            val jshape = new long[rank * 2 + 4];
            for (int i = 0; i < jshape.length; i++) {
                jshape[i] = shapeInfo.get(i);
            }

            val shapeOf = Shape.shapeOf(jshape);
            val stridesOf = Shape.stridesOf(jshape);
            val order = Shape.order(jshape);
            val array = Nd4j.create(shapeOf, stridesOf, 0, order);

            Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(array), buffer, ArrayUtil.prod(shapeOf) * array.dataType().width());
            //AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite();
            // constant condition keeps the mapping code below compilable while this path is unfinished
            if (1 > 0)
                throw new UnsupportedOperationException("Graph variable extraction is not implemented yet");

            String nodeName = nativeOps.getVariableName(var);
            newMap.put(nodeName, array);
        }

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        nativeOps.deleteVariablesSet(result);

        return newMap;
    }

    @Override
    public void forgetGraph(long id) {
        nativeOps.unregisterGraph(null, id);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());
    }

    /**
     * This method sets the desired number of elements per thread, for performance optimization purposes.
     * E.g. if an array contains 2048 elements and the threshold is set to 1024, 2 threads will be used for the op execution.
     * <p>
     * Default value: 1024
     *
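     * A minimal usage sketch:
     * <pre>{@code
     * // with a threshold of 1024, a 2048-element array may be split across 2 threads
     * Nd4j.getExecutioner().setElementsThreshold(1024);
     * }</pre>
     *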
     * @param threshold desired number of elements per thread
     */
    @Override
    public void setElementsThreshold(int threshold) {
        nativeOps.setElementThreshold(threshold);
    }

    /**
     * This method sets the desired number of sub-arrays per thread, for performance optimization purposes.
     * E.g. if a matrix has shape 64 x 128 and the threshold is set to 8, each thread will process 8 sub-arrays
     * (given an 8-core CPU). If your CPU has, say, 4 cores, only 4 threads will be spawned, and each will process 16 sub-arrays.
     * <p>
     * Default value: 8
     *
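     * A minimal usage sketch:
     * <pre>{@code
     * // with 64 TADs and a threshold of 8, up to 64 / 8 = 8 threads may be used
     * Nd4j.getExecutioner().setTadThreshold(8);
     * }</pre>
     *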
     * @param threshold desired number of sub-arrays per thread
     */
    @Override
    public void setTadThreshold(int threshold) {
        nativeOps.setTADThreshold(threshold);
    }


    @Override
    public ExecutionerType type() {
        return ExecutionerType.CUDA;
    }

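    /**
     * Decodes the UTF-8 string stored at the given index of a {@link CudaUtf8Buffer}.
     */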
    @Override
    public String getString(DataBuffer buffer, long index) {
        Preconditions.checkArgument(buffer instanceof CudaUtf8Buffer, "Expected CudaUtf8Buffer");

        val addr = ((LongIndexer) buffer.indexer()).get(index);
        val ptr = new PagedPointer(addr);
        val str = new Nd4jCuda.utf8string(ptr);
        return str._buffer().capacity(str._length()).getString();
    }

    @Override
    public boolean isExperimentalMode() {
        return experimentalMode.get();
    }

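    /**
     * Applies the given scatter update op: for each index, the matching TAD (view along
     * {@code axis}) of {@code array} is updated with the corresponding TAD of {@code updates}.
     */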
    @Override
    public void scatterUpdate(ScatterUpdate.UpdateOp op, @NonNull INDArray array, @NonNull INDArray indices, @NonNull INDArray updates, long[] axis) {
        val context = AtomicAllocator.getInstance().getDeviceContext();

        val tadX = tadManager.getTADOnlyShapeInfo(array, axis);
        val tadY = tadManager.getTADOnlyShapeInfo(updates, axis);

        if (tadY.getSecond().length() != indices.length())
            throw new IllegalStateException("Number of updates doesn't match number of indices. Bad dimensions used?");

        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));

        val stuff = extraz.get().put(null, context.getOldStream());

        nativeOps.scatterUpdate(stuff, op.ordinal(), (int) indices.length(),
                null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(tadX.getFirst()), null,
                AtomicAllocator.getInstance().getPointer(array, context), (LongPointer) AtomicAllocator.getInstance().getPointer(tadX.getFirst()), (LongPointer) AtomicAllocator.getInstance().getPointer(tadX.getSecond()),
                null, (LongPointer) AtomicAllocator.getInstance().getHostPointer(tadY.getFirst()), null,
                AtomicAllocator.getInstance().getPointer(updates, context), (LongPointer) AtomicAllocator.getInstance().getPointer(tadY.getFirst()), (LongPointer) AtomicAllocator.getInstance().getPointer(tadY.getSecond()),
                AtomicAllocator.getInstance().getHostPointer(indices), (LongPointer) AtomicAllocator.getInstance().getHostPointer(indices.shapeInfoDataBuffer()),
                AtomicAllocator.getInstance().getPointer(indices, context), (LongPointer) AtomicAllocator.getInstance().getPointer(indices.shapeInfoDataBuffer(), context));

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());
    }

    @Override
    public OpContext buildContext() {
        return new CudaOpContext();
    }

    @Override
    public INDArray[] exec(CustomOp op, OpContext context) {
        Nd4j.getExecutioner().commit();
        long st = profilingConfigurableHookIn(op, context);
        if(op instanceof UserDefinedCustomOp) {
            ((UserDefinedCustomOp) op).exec(context);
            return context.getOutputArrays().toArray(new INDArray[0]);
        }



        val status = nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer());
        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        if (status != 0)
            throw new RuntimeException("Op [" + op.opName() + "] execution failed");

        // check if input && output needs update
        for (val in:op.inputArguments()) {
            if (!in.isEmpty())
                ((BaseCudaDataBuffer) in.data()).actualizePointerAndIndexer();
        }

        for (val out:op.outputArguments()) {
            if (!out.isEmpty()) {
                ((BaseCudaDataBuffer) out.data()).actualizePointerAndIndexer();
                AtomicAllocator.getInstance().tickDeviceWrite(out);
            }

        }


        profilingConfigurableHookOut(op, context, st);

        if (context.getOutputArrays().isEmpty())
            return new INDArray[0];
        else
            return context.getOutputArrays().toArray(new INDArray[context.getOutputArrays().size()]);
    }

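    /**
     * Gathers basic statistics (min/max/mean/std plus NaN/Inf and sign counts) for the
     * given array via the native inspectArray call.
     */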
    @Override
    public INDArrayStatistics inspectArray(@NonNull INDArray array) {
        val debugInfo = new Nd4jCuda.DebugInfo();
        val ctx = AtomicAllocator.getInstance().getDeviceContext();
        AtomicAllocator.getInstance().synchronizeHostData(array);

        if (extraz.get() == null)
            extraz.set(new PointerPointer(32));

        val extras = extraz.get().put(
                null,
                ctx.getOldStream(),
                AtomicAllocator.getInstance().getDeviceIdPointer(),
                ctx.getBufferAllocation(),
                ctx.getBufferReduction(),
                ctx.getBufferScalar(),
                ctx.getBufferSpecial());


        nativeOps.inspectArray(extras,
                AtomicAllocator.getInstance().getHostPointer(array),
                (LongPointer) AtomicAllocator.getInstance().getHostPointer(array.shapeInfoDataBuffer()),
                AtomicAllocator.getInstance().getPointer(array, ctx),
                (LongPointer) AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()),
                debugInfo);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        return INDArrayStatistics.builder()
                .minValue(debugInfo._minValue())
                .maxValue(debugInfo._maxValue())
                .meanValue(debugInfo._meanValue())
                .stdDevValue(debugInfo._stdDevValue())
                .countInf(debugInfo._infCount())
                .countNaN(debugInfo._nanCount())
                .countNegative(debugInfo._negativeCount())
                .countPositive(debugInfo._positiveCount())
                .countZero(debugInfo._zeroCount())
                .build();
    }


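    /**
     * Builds a constant shapeInfo buffer (shape, strides, element-wise stride, order,
     * data type and empty flag) backed by the native constant shape cache.
     */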
    @Override
    public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, boolean empty) {
        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        val dbf = nativeOps.shapeBuffer(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, empty);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        val result = new CudaLongDataBuffer(nativeOps.getConstantShapeBufferPrimary(dbf), nativeOps.getConstantShapeBufferSpecial(dbf), Shape.shapeInfoLength(shape.length));


        return result;
    }

    @Override
    public DataBuffer createShapeInfo(long[] shape, long[] stride, long elementWiseStride, char order, DataType dtype, long extras) {
        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        val dbf = nativeOps.shapeBufferEx(shape.length, new LongPointer(shape), new LongPointer(stride), dtype.toInt(), order, elementWiseStride, extras);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        val result = new CudaLongDataBuffer(nativeOps.getConstantShapeBufferPrimary(dbf), nativeOps.getConstantShapeBufferSpecial(dbf), Shape.shapeInfoLength(shape.length));


        return result;
    }

    @Override
    public TadPack tadShapeInfoAndOffsets(INDArray array, long[] dimension) {
        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        OpaqueTadPack pack = nativeOps.tadOnlyShapeInfo((LongPointer) array.shapeInfoDataBuffer().addressPointer(), new LongPointer(ArrayUtil.toLongArray(dimension)), dimension.length);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        val tadShape = new CudaLongDataBuffer(nativeOps.getPrimaryShapeInfo(pack), nativeOps.getSpecialShapeInfo(pack), nativeOps.getShapeInfoLength(pack));
        val tadOffsets = new CudaLongDataBuffer(nativeOps.getPrimaryOffsets(pack), nativeOps.getSpecialOffsets(pack), nativeOps.getNumberOfTads(pack));


        return new TadPack(tadShape, tadOffsets);
    }

    @Override
    public DataBuffer createConstantBuffer(long[] values, DataType desiredType) {
        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        val dbf = nativeOps.constantBufferLong(desiredType.toInt(), new LongPointer(values), values.length);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType);
        buffer.setConstant(true);

        return buffer;
    }

    @Override
    public DataBuffer createConstantBuffer(double[] values, DataType desiredType)  {
        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        val dbf = nativeOps.constantBufferDouble(desiredType.toInt(), new DoublePointer(values), values.length);

        if (nativeOps.lastErrorCode() != 0)
            throw new RuntimeException(nativeOps.lastErrorMessage());

        val buffer = Nd4j.createBuffer(nativeOps.getConstantDataBufferPrimary(dbf), nativeOps.getConstantDataBufferSpecial(dbf), values.length, desiredType);
        buffer.setConstant(true);

        return buffer;
    }

    @Override
    public int useCount(DataBuffer buffer){
        return nativeOps.dbUseCount(((BaseCudaDataBuffer) buffer).getOpaqueDataBuffer());
    }


}