
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.jita.allocator.impl;

import lombok.Getter;
import lombok.NonNull;
import lombok.val;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.enums.Aggressiveness;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.time.Ring;
import org.nd4j.jita.allocator.time.rings.LockedRing;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.constant.ConstantProtector;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.jita.handler.MemoryHandler;
import org.nd4j.jita.handler.impl.CudaZeroHandler;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.OpaqueDataBuffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * Just-in-Time Allocator for CUDA.
 * <p>
 * This class is the foundation for pre-allocated memory management for CUDA.
 * Essentially, it is a sophisticated garbage collector for both zero-copy memory and multiple device memories.
 * There are multiple possible data-movement directions, but the general path is:
 * host memory (issued on the JVM side) ->
 *          zero-copy pinned memory (which is allocated for everything) ->
 *                  device memory (where data is moved from zero-copy memory, if it is used actively enough)
 * and the backward movement, if memory isn't used anymore (e.g. the originating INDArray was collected by the JVM GC),
 * or if it's not popular enough to be kept in device memory.
 * <p>
 * The mechanism is as lock-free as possible. This is achieved using three-state memory signalling: Tick/Tack/Toe.
 * Tick: the memory chunk (or its part) is being accessed on the device
 * Tack: the device access session for the memory chunk (or its part) has finished
 * Toe: the memory chunk is locked for some reason. Possible reasons:
 *              memory synchronization is ongoing, host->gpu or gpu->host
 *              memory relocation is ongoing, zero->gpu, gpu->zero, or gpu->host
 *              memory removal is ongoing.
 * So memory that is used for internal calculations, and not interfered with by manual changes (e.g. putRow),
 * is always available without locks.
 *
 * @author
 */
public class AtomicAllocator implements Allocator {
    private static final AtomicAllocator INSTANCE = new AtomicAllocator();

    private Configuration configuration;

    private transient MemoryHandler memoryHandler;

    // we have single tracking point for allocation points, since we're not going to cycle through it any time soon
    private Map<Long, AllocationPoint> allocationsMap = new ConcurrentHashMap<>();

    private static Logger log = LoggerFactory.getLogger(AtomicAllocator.class);

        locks for internal resources
    private ReentrantReadWriteLock globalLock = new ReentrantReadWriteLock();
    private ReentrantReadWriteLock externalsLock = new ReentrantReadWriteLock();

    private final AtomicBoolean wasInitialised = new AtomicBoolean(false);

    private final Ring deviceLong = new LockedRing(30);
    private final Ring deviceShort = new LockedRing(30);

    private final Ring zeroLong = new LockedRing(30);
    private final Ring zeroShort = new LockedRing(30);

    public static AtomicAllocator getInstance() {
        if (INSTANCE == null)
            throw new RuntimeException("AtomicAllocator is NULL");
        return INSTANCE;

    // Shared ConstantProtector instance; static, and populated by the private constructor below.
    protected static ConstantProtector protector;

    /**
     * Builds the singleton: reads the CUDA environment configuration, creates the CudaZeroHandler
     * memory handler and initialises it with that configuration and this allocator.
     */
    private AtomicAllocator() {
        this.configuration = CudaEnvironment.getInstance().getConfiguration();

        this.memoryHandler = new CudaZeroHandler();

        // NOTE: `this` is handed to the handler before the constructor has finished.
        this.memoryHandler.init(configuration, this);

        // protector is a static field: assign it through the class, not through `this`.
        AtomicAllocator.protector = ConstantProtector.getInstance();
    }


    protected Map<Long, AllocationPoint> allocationsMap(){
        return allocationsMap;

    // NOTE(review): no statements appear in this method body in this revision, so as shown it is a no-op
    // (it does not touch wasInitialised or configuration) — confirm that nothing was lost here.
    public void applyConfiguration() {

    /**
     * Returns the CudaContext for the current thread.
     *
     * @return the CUDA context provided by the memory handler
     */
    public CudaContext getDeviceContext() {
        // FIXME: proper lock avoidance required here
        return memoryHandler.getDeviceContext();

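    /**
     * Sets the MemoryHandler implementation to be used internally and initialises it
     * with the current configuration and this allocator.
     *
     * @param memoryHandler the handler to install; must not be null
     */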
    public void setMemoryHandler(@NonNull MemoryHandler memoryHandler) {

        this.memoryHandler = memoryHandler;
        this.memoryHandler.init(configuration, this);


    /**
     * Consumes and applies the configuration passed in as an argument.
     * PLEASE NOTE: This method should only be used BEFORE any calculations have started;
     * the call has no effect once {@code wasInitialised} is set.
     *
     * @param configuration configuration bean to be applied
     */
    public void applyConfiguration(@NonNull Configuration configuration) {
        if (!wasInitialised.get()) {

            this.configuration = configuration;


    /**
     * Returns the current Allocator configuration.
     *
     * @return current configuration
     */
    public Configuration getConfiguration() {
        try {
            return configuration;
        } finally {

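    /**
     * Returns the device pointer that is currently valid for the given buffer.
     *
     * @param buffer  the buffer to resolve
     * @param context CUDA context used for the lookup
     * @return device pointer for the buffer
     */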
    public Pointer getPointer(@NonNull DataBuffer buffer, CudaContext context) {
        return memoryHandler.getDevicePointer(buffer, context);

    public Pointer getPointer(DataBuffer buffer) {
        return memoryHandler.getDevicePointer(buffer, getDeviceContext());

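    /**
     * Returns the device pointer for the given buffer.
     * Note: {@code shape} and {@code isView} are currently not consulted; the result is the same as
     * {@code getPointer(buffer, context)}.
     *
     * @param buffer  the buffer to resolve
     * @param shape   allocation shape (unused)
     * @param isView  whether the buffer is a view (unused)
     * @param context CUDA context used for the lookup
     */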
    public Pointer getPointer(DataBuffer buffer, AllocationShape shape, boolean isView, CudaContext context) {
        return memoryHandler.getDevicePointer(buffer, context);

    /**
     * Returns the device pointer for the given INDArray, or {@code null} if the array is empty.
     *
     * @param array   the array to resolve
     * @param context CUDA context used for the lookup
     */
    public Pointer getPointer(INDArray array, CudaContext context) {
        if (array.isEmpty())
            return null;

        return memoryHandler.getDevicePointer(, context);

    /**
     * Returns the host pointer for the given INDArray, or {@code null} if the array is empty.
     *
     * @param array the array to resolve
     */
    public Pointer getHostPointer(INDArray array) {
        if (array.isEmpty())
            return null;

        return memoryHandler.getHostPointer(;

    /**
     * Returns the host pointer that is currently valid for the given buffer.
     *
     * @param buffer the buffer to resolve
     */
    public Pointer getHostPointer(DataBuffer buffer) {
        return memoryHandler.getHostPointer(buffer);

    /**
     * Should be called to make sure that the host-side data of the given array is up to date.
     *
     * @param array the array whose host data should be synchronized
     */
    public void synchronizeHostData(INDArray array) {
        if (array.isEmpty() || array.isS())

        val buffer = == null ? :;

    /**
     * Should be called to make sure that the host-side data of the given buffer is up to date.
     *
     * @param buffer the buffer whose host data should be synchronized
     */
    public void synchronizeHostData(DataBuffer buffer) {
        // we actually need synchronization only in a device-dependent environment; no-op otherwise. Managed by native code.
        // NOTE(review): no statement follows this comment in this revision, so as shown the method does nothing.
        // NativeOpsHolder is imported but not used anywhere in the visible code — confirm whether a native
        // synchronization call belongs here.

    /**
     * Returns the CUDA device id recorded in the allocation point of the given array.
     *
     * @param array the array to inspect
     * @return the device id
     */
    public Integer getDeviceId(INDArray array) {
        return getAllocationPoint(array).getDeviceId();

    /**
     * Releases the memory allocated for the given allocation point.
     *
     * @param point the allocation point to release
     */
    public void freeMemory(AllocationPoint point) {
        // Asks the memory handler to forget the point, using DEVICE or HOST depending on its allocation status.
        // NOTE(review): closing braces are missing in this revision. As written, the `} else {` below closes the
        // inner `if (point.getHostPointer() != null)`, not the DEVICE check, even though the indentation suggests
        // otherwise — confirm the intended nesting.
        // NOTE(review): getMemoryHandler() is not declared in the visible code (lombok.Getter is imported) — verify.
        if (point.getAllocationStatus() == AllocationStatus.DEVICE) {

            if (point.getHostPointer() != null) {
                this.getMemoryHandler().forget(point, AllocationStatus.DEVICE);
        } else {
            // call it only once
            if (point.getHostPointer() != null) {
                this.getMemoryHandler().forget(point, AllocationStatus.HOST);


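    /**
     * Allocates the required chunk of memory at the location implied by the configured memory model:
     * IMMEDIATE uses the memory handler's initial location, DELAYED always starts on the host.
     *
     * @param buffer         buffer the memory is allocated for
     * @param requiredMemory shape/size of the required memory
     * @param initialize     passed through to the location-specific overload
     * @return the new allocation point, or {@code null} if the memory model is neither IMMEDIATE nor DELAYED
     */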
    public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, boolean initialize) {
        // by default we allocate on initial location
        AllocationPoint point = null;

        if (configuration.getMemoryModel() == Configuration.MemoryModel.IMMEDIATE) {
            point = allocateMemory(buffer, requiredMemory, memoryHandler.getInitialLocation(), initialize);
        } else if (configuration.getMemoryModel() == Configuration.MemoryModel.DELAYED) {
            // for DELAYED memory model we allocate only host memory, regardless of firstMemory configuration value
            point = allocateMemory(buffer, requiredMemory, AllocationStatus.HOST, initialize);

        return point;

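    /**
     * Allocates the required chunk of memory in a specific location.
     * <p>
     * PLEASE NOTE: Do not use this method unless you're 100% sure what you're doing.
     *
     * @param buffer         buffer the memory is allocated for
     * @param requiredMemory shape/size of the required memory
     * @param location       where to allocate; only HOST and DEVICE are supported
     * @param initialize     currently not consulted
     * @throws UnsupportedOperationException for any other location
     */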
    public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape requiredMemory, AllocationStatus location, boolean initialize) {
        switch(location) {
            case HOST:
                OpaqueDataBuffer opaqueDataBuffer = OpaqueDataBuffer.externalizedDataBuffer(buffer.length(), buffer.dataType(), buffer.pointer(),null);
                return new AllocationPoint(opaqueDataBuffer,requiredMemory.getNumberOfBytes());
            case DEVICE:
                OpaqueDataBuffer opaqueDataBuffer2 =  OpaqueDataBuffer.allocateDataBuffer(buffer.length(),buffer.dataType(),true);
                return new AllocationPoint(opaqueDataBuffer2,requiredMemory.getNumberOfBytes());
            case DELAYED:
            case CONSTANT:
            case UNDEFINED:
            case DEALLOCATED:
                throw new UnsupportedOperationException("Unable to allocate memory.");

    /**
     * Returns the AllocationPoint registered for the specified tracking id.
     *
     * @param objectId tracking id
     * @return the allocation point, or {@code null} if the id is not tracked
     */
    protected AllocationPoint getAllocationPoint(@NonNull Long objectId) {
        return allocationsMap.get(objectId);

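    /**
     * Frees the native system (zero-copy) memory referenced by the specified tracking id/AllocationPoint.
     *
     * @param bucketId bucket serving the allocation
     * @param objectId tracking id
     * @param point    the allocation point
     * @param copyback passed through to the memory handler
     */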
    protected void purgeZeroObject(Long bucketId, Long objectId, AllocationPoint point, boolean copyback) {

        memoryHandler.purgeZeroObject(bucketId, objectId, point, copyback);

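    /**
     * Frees the native device memory referenced by the specified tracking id/AllocationPoint.
     *
     * @param threadId id of the thread owning the allocation
     * @param deviceId id of the device
     * @param objectId tracking id
     * @param point    the allocation point
     * @param copyback passed through to the memory handler
     */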
    protected void purgeDeviceObject(Long threadId, Integer deviceId, Long objectId, AllocationPoint point,
                                     boolean copyback) {
        memoryHandler.purgeDeviceObject(threadId, deviceId, objectId, point, copyback);

        // since we can't allow java object without native memory, we explicitly specify that memory is handled using HOST memory only, after device memory is released

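    /**
     * Seeks unused zero-copy memory allocations in the given bucket.
     *
     * @param bucketId       id of the bucket serving the allocations
     * @param aggressiveness scales the access-frequency thresholds (a higher ordinal gives a larger threshold)
     * @return size of memory that was deallocated
     */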
    protected synchronized long seekUnusedZero(Long bucketId, Aggressiveness aggressiveness) {
        // Walks the host tracking points of one bucket and purges points whose buffer is gone.
        // NOTE(review): freeSpace is only read (never updated) in the visible code, so this returns 0 as shown.
        AtomicLong freeSpace = new AtomicLong(0);

        int totalElements = (int) memoryHandler.getAllocatedHostObjects(bucketId);

        // these 2 variables will contain jvm-wise memory access frequencies
        float shortAverage = zeroShort.getAverage();
        float longAverage = zeroLong.getAverage();

        // threshold is calculated based on the aggressiveness specified via configuration;
        // the divisor ranges from values().length down to 1, so it is never zero.
        // NOTE(review): shortThreshold/longThreshold are not used anywhere in the visible code.
        float shortThreshold = shortAverage / (Aggressiveness.values().length - aggressiveness.ordinal());
        float longThreshold = longAverage / (Aggressiveness.values().length - aggressiveness.ordinal());

        // simple counter for dereferenced objects
        AtomicInteger elementsDropped = new AtomicInteger(0);
        AtomicInteger elementsSurvived = new AtomicInteger(0);

        for (Long object : memoryHandler.getHostTrackingPoints(bucketId)) {
            AllocationPoint point = getAllocationPoint(object);

            // point can be null, if memory was promoted to device and was deleted there
            // NOTE(review): this `if` has no statement of its own in this revision, so as written it guards the
            // following `if` — presumably a `continue;` was intended here; confirm.
            if (point == null)

            if (point.getAllocationStatus() == AllocationStatus.HOST) {

                    // NOTE(review): the two lines below are comment text whose /* */ delimiters are missing.
                    Check if memory points to non-existant buffer, using externals.
                    If externals don't have specified buffer - delete reference.
                if (point.getBuffer() == null) {
                    purgeZeroObject(bucketId, object, point, false);
                    // Always throws after purging, so elementsDropped/freeSpace are never updated on this path.
                    throw new UnsupportedOperationException("Pew-pew");

                } else {

            } else {


        log.debug("Zero {} elements checked: [{}], deleted: {}, survived: {}", bucketId, totalElements,
                elementsDropped.get(), elementsSurvived.get());

        return freeSpace.get();

    /**
     * Seeks unused device memory allocations for the specified thread and device.
     *
     * @param threadId       id of the thread, retrieved via Thread.currentThread().getId()
     * @param deviceId       id of the device
     * @param aggressiveness scales the access-frequency thresholds (a higher ordinal gives a larger threshold)
     * @return size of memory that was deallocated
     */
    protected long seekUnusedDevice(Long threadId, Integer deviceId, Aggressiveness aggressiveness) {
        // Walks the device tracking points of one device and purges DEVICE points whose buffer is gone.
        // NOTE(review): freeSpace is only read (never updated) in the visible code, so this returns 0 as shown.
        AtomicLong freeSpace = new AtomicLong(0);

        //  int initialSize = allocations.size();

        // these 2 variables will contain jvm-wise memory access frequencies
        float shortAverage = deviceShort.getAverage();
        float longAverage = deviceLong.getAverage();

        // threshold is calculated based on the aggressiveness specified via configuration
        // NOTE(review): shortThreshold/longThreshold are not used anywhere in the visible code.
        float shortThreshold = shortAverage / (Aggressiveness.values().length - aggressiveness.ordinal());
        float longThreshold = longAverage / (Aggressiveness.values().length - aggressiveness.ordinal());

        AtomicInteger elementsDropped = new AtomicInteger(0);
        AtomicInteger elementsMoved = new AtomicInteger(0);
        AtomicInteger elementsSurvived = new AtomicInteger(0);

        for (Long object : memoryHandler.getDeviceTrackingPoints(deviceId)) {
            // NOTE(review): unlike seekUnusedZero, there is no null check on `point` before it is dereferenced.
            AllocationPoint point = getAllocationPoint(object);
                // NOTE(review): the two lines below are comment text whose /* */ delimiters are missing.
                Check if memory points to non-existent buffer, using externals.
                If externals don't have specified buffer - delete reference.
            if (point.getBuffer() == null) {
                if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
                    // we deallocate device memory
                    purgeDeviceObject(threadId, deviceId, object, point, false);

                    // and we deallocate host memory, since object is dereferenced
                    // NOTE(review): no host deallocation statement follows in this revision.

                    throw new UnsupportedOperationException("Unable to find device memory for null buffer!");
                } ;
            } else {


        log.debug("Thread/Device [" + threadId + "/" + deviceId + "] elements purged: [" + elementsDropped.get()
                + "]; Relocated: [" + elementsMoved.get() + "]; Survivors: [" + elementsSurvived.get() + "]");

        return freeSpace.get();

    /**
     * Performs an asynchronous memcpy, if that is available on the current hardware.
     *
     * @param dstBuffer  destination buffer
     * @param srcPointer source pointer
     * @param length     number of bytes to copy
     * @param dstOffset  offset in the destination buffer
     */
    public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
        this.memoryHandler.memcpyAsync(dstBuffer, srcPointer, length, dstOffset);

    public void memcpySpecial(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
        this.memoryHandler.memcpySpecial(dstBuffer, srcPointer, length, dstOffset);

    public void memcpyDevice(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset,
                             CudaContext context) {
        this.memoryHandler.memcpyDevice(dstBuffer, srcPointer, length, dstOffset, context);

    /**
     * Performs a blocking memcpy.
     *
     * @param dstBuffer  destination buffer
     * @param srcPointer source pointer
     * @param length     number of bytes to copy
     * @param dstOffset  offset in the destination buffer
     */
    public void memcpyBlocking(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
        this.memoryHandler.memcpyBlocking(dstBuffer, srcPointer, length, dstOffset);

    /**
     * Performs a blocking memcpy from one buffer into another.
     *
     * @param dstBuffer destination buffer
     * @param srcBuffer source buffer
     */
    public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
        this.memoryHandler.memcpy(dstBuffer, srcBuffer);

    // NOTE(review): this method has no statements in this revision, so as shown it is a no-op — confirm that
    // nothing was lost.
    public void tickHostWrite(DataBuffer buffer) {

    // NOTE(review): this method has no statements in this revision, so as shown it is a no-op — confirm that
    // nothing was lost.
    public void tickHostWrite(INDArray array) {

    // NOTE(review): this method has no statements in this revision, so as shown it is a no-op — confirm that
    // nothing was lost.
    public void tickDeviceWrite(INDArray array) {

    public AllocationPoint getAllocationPoint(INDArray array) {
        return getAllocationPoint(;

    public AllocationPoint getAllocationPoint(DataBuffer buffer) {
        return ((BaseCudaDataBuffer) buffer).getAllocationPoint();

    /**
     * Returns the device id for the current thread.
     * All values >= 0 are considered valid device ids; all values < 0 are considered stubs.
     *
     * @return the device id reported by the memory handler
     */
    public Integer getDeviceId() {
        return memoryHandler.getDeviceId();

    /** Returns {@link #getDeviceId()} wrapped as a {@link Pointer}. */
    public Pointer getDeviceIdPointer() {
        return new CudaPointer(getDeviceId());

    public void registerAction(CudaContext context, INDArray result, INDArray... operands) {
        memoryHandler.registerAction(context, result, operands);

    public FlowController getFlowController() {
        return memoryHandler.getFlowController();

    public DataBuffer getConstantBuffer(int[] array) {
        return Nd4j.getConstantHandler().getConstantBuffer(array, DataType.INT);

    public DataBuffer getConstantBuffer(long[] array) {
        return Nd4j.getConstantHandler().getConstantBuffer(array, DataType.LONG);

    public DataBuffer getConstantBuffer(float[] array) {
        return Nd4j.getConstantHandler().getConstantBuffer(array, DataType.FLOAT);

    public DataBuffer getConstantBuffer(double[] array) {
        return Nd4j.getConstantHandler().getConstantBuffer(array, DataType.DOUBLE);

    public DataBuffer moveToConstant(DataBuffer dataBuffer) {
        return dataBuffer;