deeplearning4j/deeplearning4j

View on GitHub
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/abstracts/Nd4jWorkspace.java

Summary

Maintainability
F
5 days
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.linalg.api.memory.abstracts;

import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.AllocationsTracker;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.*;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.memory.pointers.PointersPair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.memory.MemoryManager;
import org.nd4j.common.util.ND4JFileUtils;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

@Slf4j
public abstract class Nd4jWorkspace implements MemoryWorkspace {
    // device id returned by the affinity manager for the thread that constructed this workspace
    @Getter
    protected int deviceId;
    // id of the thread that constructed this workspace
    @Getter
    protected Long threadId;

    // CIRCULAR when the reset policy is ENDOFBUFFER_REACHED, SCOPED otherwise (decided in the constructor)
    protected Type workspaceType = Type.SCOPED;

    // not referenced within this class
    public static final long SAFETY_OFFSET = 1024L;

    // workspace id, also used as the key for the AllocationsTracker
    @Getter
    protected String id;

    // current size of the workspace buffer, in bytes
    protected AtomicLong currentSize = new AtomicLong(0);
    // offset of the next free byte within the workspace buffer; deviceOffset is kept equal to hostOffset by alloc()
    protected AtomicLong hostOffset = new AtomicLong(0);
    protected AtomicLong deviceOffset = new AtomicLong(0);

    // pointers to the workspace buffer itself
    protected PointersPair workspace = new PointersPair();

    // this memory manager implementation will be used to allocate real memory for this workspace
    protected MemoryManager memoryManager;

    // not referenced within this class
    protected AtomicBoolean isLearning = new AtomicBoolean(true);
    // false while the workspace is temporarily disabled via toggleWorkspaceUse(); alloc() then bypasses the buffer
    protected AtomicBoolean isUsed = new AtomicBoolean(true);

    // number of allocations made while the workspace was disabled; zeroed on scope enter and on close
    protected AtomicLong disabledCounter = new AtomicLong(0);


    // number of completed close() calls that actually left the scope
    protected AtomicLong cyclesCount = new AtomicLong(0);
    // incremented in close() once per stepsNumber cycles; see getStepNumber()
    protected AtomicLong stepsCount = new AtomicLong(0);
    // for cyclic workspaces with OVERALLOCATE this is (int) (overallocationLimit + 1), 1 otherwise
    protected int stepsNumber = 1;

    // bytes requested during the previous cycle (copied from cycleAllocations in close())
    protected AtomicLong lastCycleAllocations = new AtomicLong(0);
    // bytes requested during the current cycle
    protected AtomicLong cycleAllocations = new AtomicLong(0);
    // total bytes handed out as spilled (external) allocations
    protected AtomicLong spilledAllocationsSize = new AtomicLong(0);
    // total bytes handed out as pinned allocations
    protected AtomicLong pinnedAllocationsSize = new AtomicLong(0);
    // largest cycleAllocations value observed so far
    protected AtomicLong maxCycle = new AtomicLong(0);
    // set when a circular workspace wraps around in alloc(); allows purging of external allocations later
    protected AtomicBoolean resetPlanned = new AtomicBoolean(false);
    // true between notifyScopeEntered() and the close() call that leaves the scope
    protected AtomicBoolean isOpen = new AtomicBoolean(false);
    // only ever reset to false in this class; presumably set to true by subclasses once memory is allocated - TODO confirm
    protected AtomicBoolean isInit = new AtomicBoolean(false);
    // true once overallocation has been applied to currentSize, so it isn't applied twice
    protected AtomicBoolean isOver = new AtomicBoolean(false);
    // true between notifyScopeBorrowed() and the matching close()
    protected AtomicBoolean isBorrowed = new AtomicBoolean(false);

    // nesting counter: bumped on re-entering an already open workspace and by tagOutOfScopeUse(), decremented in close()
    protected AtomicInteger tagScope = new AtomicInteger(0);

    // enables verbose per-allocation logging
    protected AtomicBoolean isDebug = new AtomicBoolean(false);
    // number of spilled allocations made by alloc() when the buffer was too small
    protected AtomicInteger externalCount = new AtomicInteger(0);
    // number of pinned allocations made by alloc() in trimmed mode
    protected AtomicInteger pinnedCount = new AtomicInteger(0);

    // trimmed mode: a circular workspace waiting for reallocation; trimmedStep remembers the step it was entered at
    protected AtomicBoolean trimmedMode = new AtomicBoolean(false);
    protected AtomicLong trimmedStep = new AtomicLong(0);

    @Getter
    protected final WorkspaceConfiguration workspaceConfiguration;

    // external allocations are purged at the end of loop
    protected List<PointersPair> externalAllocations = new ArrayList<>();

    // pinned allocations are purged with delay, used for circular mode only
    protected Queue<PointersPair> pinnedAllocations = new LinkedTransferQueue<>();

    // workspace that was current when this one was entered; restored by close()
    @Setter
    protected MemoryWorkspace previousWorkspace;
    // workspace that was current when this one was borrowed; restored by close()
    protected MemoryWorkspace borrowingWorkspace;

    // size of a single block of a circular workspace: currentSize before overallocation is applied
    protected AtomicLong initialBlockSize = new AtomicLong(0);

    // unique id obtained from the workspace manager at construction time
    protected String guid;

    // backing file, used only with LocationPolicy.MMAP
    protected File tempFile;

    // incremented every time a new scope is entered (re-entry of an already open scope doesn't count)
    protected AtomicLong generationId = new AtomicLong(0);

    // this field is used as alignment base for all allocations within this workspace
    public final static int alignmentBase = 32;

    /**
     * Creates a workspace with the given configuration, using the inherited DEFAULT_ID as workspace id.
     *
     * @param configuration workspace configuration, must not be null
     */
    public Nd4jWorkspace(@NonNull WorkspaceConfiguration configuration) {
        this(configuration, DEFAULT_ID);
    }

    /**
     * Creates a workspace with the given configuration and id.
     *
     * Besides storing ids and managers, this constructor:
     * - registers the workspace id with the AllocationsTracker;
     * - picks the workspace type: CIRCULAR for ResetPolicy.ENDOFBUFFER_REACHED, SCOPED otherwise;
     * - for LocationPolicy.MMAP, prepares the backing file;
     * - finally calls {@link #init()}.
     *
     * PLEASE NOTE: the passed configuration may be modified: when an existing, large enough MMAP file is used,
     * its length is written back as the configuration's initial size.
     *
     * @param configuration workspace configuration, must not be null
     * @param workspaceId id of this workspace, must not be null
     * @throws ND4JIllegalStateException if a cyclic OVERALLOCATE workspace has an overallocation limit below 1.0,
     *                                   or if the MMAP settings don't allow a backing file to be chosen
     * @throws RuntimeException if the MMAP backing file can't be created or filled
     */
    public Nd4jWorkspace(@NonNull WorkspaceConfiguration configuration, @NonNull String workspaceId) {
        this.workspaceConfiguration = configuration;
        this.id = workspaceId;
        this.threadId = Thread.currentThread().getId();
        this.guid = Nd4j.getWorkspaceManager().getUUID();
        this.memoryManager = Nd4j.getMemoryManager();
        this.deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        AllocationsTracker.getInstance().registerWorkspace(this.id);
        // and actual workspace allocation
        // NOTE(review): currentSize is taken here, before the MMAP branch below may change the configured initial size
        currentSize.set(workspaceConfiguration.getInitialSize());

        if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED)
            workspaceType = Type.CIRCULAR;
        else
            workspaceType = Type.SCOPED;

        // cyclic + overallocating workspace: the buffer consists of (limit + 1) blocks
        if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED
                && workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE) {
            if (workspaceConfiguration.getOverallocationLimit() < 1.0)
                throw new ND4JIllegalStateException(
                        "For cyclic workspace overallocation should be positive integral value.");

            stepsNumber = (int) (workspaceConfiguration.getOverallocationLimit() + 1);
            log.trace("Steps: {}", stepsNumber);
        }


        // validate mmap option
        if (configuration.getPolicyLocation() == LocationPolicy.MMAP) {
            // either an explicit file path is given, or a temp file is created for a positive initialSize
            if (configuration.getTempFilePath() != null) {
                tempFile = new File(configuration.getTempFilePath());

                // file is missing/empty or smaller than requested: (re)write it, which needs a positive initialSize
                if (tempFile.length() == 0 || tempFile.length() < configuration.getInitialSize()) {
                    if (configuration.getInitialSize() > 0) {
                        try {
                            fillFile(tempFile, configuration.getInitialSize());
                        } catch (Exception e) {
                            throw new RuntimeException(e);
                        }
                    } else {
                        throw new ND4JIllegalStateException("Memory-mapped file should have positive length.");
                    }
                } else {
                    // existing file is large enough: its length becomes the configured initial size
                    configuration.setInitialSize(tempFile.length());
                }
            } else if (configuration.getInitialSize() > 0) {
                try {
                    tempFile = ND4JFileUtils.createTempFile("workspace", "tempMMAP");
                    tempFile.deleteOnExit();

                    // fill temp file with zeroes, up to initialSize bytes
                    fillFile(tempFile, configuration.getInitialSize());

                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            } else
                throw new ND4JIllegalStateException("MMAP target file path should be non-null or workspace initialSize should be >0 for temp file");
        }

        // NOTE(review): init() is protected and non-final, so a subclass override runs before the subclass constructor body
        init();
    }

    /**
     * Returns the type of this workspace, as decided in the constructor:
     * CIRCULAR for ResetPolicy.ENDOFBUFFER_REACHED, SCOPED otherwise.
     *
     * @return workspace type
     */
    @Override
    public Type getWorkspaceType() {
        return workspaceType;
    }

    /**
     * Creates (or truncates) the given file and fills it with exactly {@code length} zero bytes.
     *
     * A non-positive {@code length} leaves the file empty.
     *
     * @param file target file
     * @param length number of zero bytes to write
     * @throws Exception if the file can't be opened or written
     */
    public static void fillFile(File file, long length) throws Exception {
        // freshly allocated Java arrays are already zero-filled
        byte[] zeroes = new byte[16384];

        try (FileOutputStream fos = new FileOutputStream(file)) {
            long remaining = length;
            while (remaining > 0) {
                // last chunk may be shorter than the buffer, so the file ends up exactly `length` bytes long
                int chunk = (int) Math.min((long) zeroes.length, remaining);
                fos.write(zeroes, 0, chunk);
                remaining -= chunk;
            }
        }
    }

    /**
     * Returns the generation counter of this workspace. The counter is incremented every time
     * a new scope is entered via {@link #notifyScopeEntered()}.
     *
     * @return current generation id
     */
    @Override
    public long getGenerationId() {
        return this.generationId.get();
    }

    /**
     * Returns the current step counter. Meaningful in circular mode only.
     *
     * @return current value of the steps counter
     */
    public long getStepNumber() {
        return this.stepsCount.get();
    }

    /**
     * Returns the accumulated size of spilled allocations.
     *
     * @return number of bytes in spilled allocations
     */
    public long getSpilledSize() {
        return this.spilledAllocationsSize.get();
    }

    /**
     * Returns the accumulated size of pinned allocations.
     *
     * @return number of bytes in pinned allocations
     */
    public long getPinnedSize() {
        return this.pinnedAllocationsSize.get();
    }

    /**
     * Returns the size of the first block of a circular workspace.
     *
     * @return initial block size, in bytes
     */
    public long getInitialBlockSize() {
        return this.initialBlockSize.get();
    }

    /**
     * Returns the workspace that was current when this one was entered.
     *
     * @return parent workspace, or null if there's none
     */
    @Override
    public MemoryWorkspace getParentWorkspace() {
        return this.previousWorkspace;
    }

    /**
     * Returns the current device memory offset within this workspace.
     *
     * @return device offset, in bytes
     */
    public long getDeviceOffset() {
        return this.deviceOffset.get();
    }

    /**
     * Returns the current host memory offset within this workspace.
     *
     * @return host offset, in bytes
     */
    public long getHostOffset() {
        return this.hostOffset.get();
    }

    /**
     * Returns the amount of memory currently allocated for this workspace.
     *
     * PLEASE NOTE: Only HOST memory is reported here.
     * On backends that use a DEVICE/HOST memory pair the DEVICE part will probably
     * have the same size, but it is not included in this value.
     *
     * @return current workspace size, in bytes
     */
    public long getCurrentSize() {
        return this.currentSize.get();
    }

    /**
     * Returns the current offset within this workspace; same value as {@link #getHostOffset()}.
     *
     * @return host offset, in bytes
     */
    @Override
    public long getCurrentOffset() {
        return this.hostOffset.get();
    }

    /**
     * Validates the configuration and finalizes {@code currentSize} before memory is actually allocated.
     *
     * For non-MMAP workspaces with a positive size, overallocation is applied once (guarded by {@code isOver}),
     * and the result is capped by the configured max size, if there is one.
     *
     * @throws IllegalArgumentException if the workspace is memory-mapped and a LearningPolicy other than NONE is set
     */
    protected void init() {
        final boolean memoryMapped = workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP;

        // in case of MMAP we don't want any learning applied
        if (memoryMapped && workspaceConfiguration.getPolicyLearning() != LearningPolicy.NONE) {
            throw new IllegalArgumentException("Workspace backed by memory-mapped file can't have LearningPolicy defined");
        }

        // nothing to size: either the workspace is empty, or it's MMAP and we don't want overallocation there
        if (memoryMapped || currentSize.get() <= 0) {
            return;
        }

        if (!isOver.get()
                && workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE
                && workspaceConfiguration.getOverallocationLimit() > 0) {
            currentSize.addAndGet((long) (currentSize.get() * workspaceConfiguration.getOverallocationLimit()));
            isOver.set(true);
        }

        final long maxSize = workspaceConfiguration.getMaxSize();
        if (maxSize > 0 && currentSize.get() > maxSize) {
            currentSize.set(maxSize);
        }
    }

    /**
     * Shortcut for {@link #alloc(long, MemoryKind, DataType, boolean)} with MemoryKind.HOST.
     *
     * @param requiredMemory number of bytes requested
     * @param type data type of the buffer to be stored in the allocated memory
     * @param initialize passed through to the 4-argument alloc
     * @return pointer to the allocated chunk
     */
    public PagedPointer alloc(long requiredMemory, DataType type, boolean initialize) {
        return this.alloc(requiredMemory, MemoryKind.HOST, type, initialize);
    }

    /**
     * Switches debug mode of this workspace on or off.
     * In debug mode allocations and cycle statistics are logged at INFO level.
     *
     * @param reallyEnable true to enable debug mode, false to disable it
     */
    @Override
    public void enableDebug(boolean reallyEnable) {
        isDebug.set(reallyEnable);
    }

    /**
     * Backend-specific hook, not called from within this class.
     * NOTE(review): judging by the name, returns the amount of memory the given array needs - confirm against implementations.
     *
     * @param arr array to inspect
     * @return memory amount for the given array
     */
    public abstract long requiredMemoryPerArray(INDArray arr);

    /**
     * Rounds the requested memory amount up to the next multiple of {@code alignmentBase}.
     * Amounts that are already a multiple of {@code alignmentBase} are returned unchanged.
     *
     * @param requiredMemory the requested memory amount, in bytes
     * @return the aligned memory amount, in bytes
     */
    public static long alignMemory(long requiredMemory) {
        // alignment is enforced to ensure CUDA doesn't blame us
        final long remainder = requiredMemory % alignmentBase;
        if (remainder == 0) {
            return requiredMemory;
        }
        return requiredMemory + (alignmentBase - remainder);
    }

    /**
     * Allocates {@code requiredMemory} bytes, rounded up by {@link #alignMemory(long)}, for this workspace.
     *
     * Depending on the workspace state, the memory comes from one of these places:
     * 1) the workspace buffer itself, at the current host offset;
     * 2) a separate "spilled" (external) allocation made via the memory manager, when the buffer is too small
     *    or the workspace is currently disabled;
     * 3) a separate "pinned" allocation made via the memory manager, while a circular workspace is in trimmed mode.
     *
     * In circular mode, a request that doesn't fit into the remaining space but does fit into the whole buffer
     * makes the workspace wrap around: offsets are reset and the allocation is retried from the beginning.
     * A request larger than the whole buffer is never retried, it goes through the spill policy instead.
     *
     * @param requiredMemory number of bytes requested
     * @param kind memory kind; reported to the allocations tracker, while this implementation always allocates HOST memory
     * @param type data type, used to compute the element capacity of the returned pointer
     * @param initialize if true, in-workspace memory is zeroed; the flag is also forwarded to the memory manager
     *                   for out-of-workspace allocations
     * @return pointer to the allocated chunk
     * @throws ND4JIllegalStateException if the request doesn't fit and the spill policy is neither EXTERNAL nor REALLOCATE
     */
    public PagedPointer alloc(long requiredMemory, MemoryKind kind, DataType type, boolean initialize) {

        /*
            just two options here:
            1) reqMem + hostOffset < totalSize, we just return pointer + offset
            2) go for either external spilled, or pinned allocation
         */

        // capacity is computed from the requested size, before alignment padding is added
        long numElements = requiredMemory / Nd4j.sizeOfDataType(type);

        // we enforce alignment to alignmentBase bytes to ensure CUDA doesn't blame us
        requiredMemory = alignMemory(requiredMemory);

        AllocationsTracker.getInstance().getTracker(this.id).allocate(type,kind,numElements,requiredMemory);

        // shortcut made to skip workspace
        if (!isUsed.get()) {
            if (disabledCounter.incrementAndGet() % 10 == 0)
                log.warn("Workspace was turned off, and wasn't enabled after {} allocations", disabledCounter.get());

            PagedPointer pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize),
                    numElements);

            externalAllocations.add(new PointersPair(pointer, null));
            AllocationsTracker.getInstance().getTracker(id).allocateExternal(type,kind,numElements,requiredMemory);
            return pointer;
        }

        /*
            Trimmed mode is possible for cyclic workspace mode. Used in AsyncDataSetIterator, MQ, etc.
            Basically idea is simple: if one of datasets coming out of iterator has size higher then expected - we should reallocate workspace to match this size.
            So, we switch to trimmed mode, and all allocations will be "pinned", and eventually workspace will be reallocated.
         */
        boolean trimmer = (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED
                && requiredMemory + cycleAllocations.get() > initialBlockSize.get()
                && initialBlockSize.get() > 0) || trimmedMode.get();

        if (trimmer && workspaceConfiguration.getPolicySpill() == SpillPolicy.REALLOCATE && !trimmedMode.get()) {
            trimmedMode.set(true);
            trimmedStep.set(stepsCount.get());
        }

        // if size is enough - allocate from workspace
        if (hostOffset.get() + requiredMemory <= currentSize.get() && !trimmer && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) {
            cycleAllocations.addAndGet(requiredMemory);
            long prevOffset = hostOffset.getAndAdd(requiredMemory);
            deviceOffset.set(hostOffset.get());

            PagedPointer ptr = workspace.getHostPointer().withOffset(prevOffset, numElements);

            if (isDebug.get())
                log.info("Workspace [{}]: Allocating array of {} bytes, capacity of {} elements, prevOffset: {}; currentOffset: {}; address: {}",
                        id, requiredMemory, numElements, prevOffset, hostOffset.get(), ptr.address());

            if (initialize)
                Pointer.memset(ptr, 0, requiredMemory);

            return ptr;
        } else {
            // if current workspace isn't enough - we allocate it separately as spilled (or pinned, in case of circular mode)

            // in case of circular mode - we just reset offsets, and start from the beginning of the workspace.
            // The retry can only succeed if the request fits into the whole buffer; without that check a request
            // larger than the buffer would reset and recurse forever, so such requests go to the spill policy below.
            if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && currentSize.get() > 0
                    && requiredMemory <= currentSize.get()
                    && !trimmer && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) {
                reset();
                resetPlanned.set(true);
                // NOTE(review): the retry passes the already aligned size, and reports to the allocations tracker once more
                return alloc(requiredMemory, kind, type, initialize);
            }


            if (isDebug.get())
                log.info("Workspace [{}]: step: {}, spilled  {} bytes, capacity of {} elements", id, stepsCount.get(),
                        requiredMemory, numElements);

            switch (workspaceConfiguration.getPolicySpill()) {
                case REALLOCATE:
                case EXTERNAL:
                    cycleAllocations.addAndGet(requiredMemory);
                    if (!trimmer) {
                        // spilled allocation: kept in externalAllocations until they are purged
                        externalCount.incrementAndGet();
                        AllocationsTracker.getInstance().getTracker(id).allocateSpilled(type,kind,numElements,requiredMemory);
                        AllocationsTracker.getInstance().getTracker(id).allocateExternal(type,kind,numElements,requiredMemory);
                        spilledAllocationsSize.addAndGet(requiredMemory);
                        PagedPointer pointer = new PagedPointer(
                                memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize),
                                numElements);

                        externalAllocations.add(new PointersPair(pointer, null));

                        return pointer;
                    } else {
                        // pinned allocation: remembered together with the step it was made at
                        pinnedCount.incrementAndGet();
                        AllocationsTracker.getInstance().getTracker(id).allocatePinned(type,kind,numElements,requiredMemory);
                        pinnedAllocationsSize.addAndGet(requiredMemory);
                        PagedPointer pointer = new PagedPointer(
                                memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize),
                                numElements);

                        pinnedAllocations.add(new PointersPair(stepsCount.get(), requiredMemory, pointer, null));


                        return pointer;
                    }
                case FAIL:
                default: {
                    throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full");
                }
            }
        }
    }

    /**
     * Does nothing in this base implementation.
     *
     * @param pointer pointer to release; ignored here
     */
    public void free(Pointer pointer) {
        // no-op for main page(s), purge for external stuff
    }

    /**
     * (Re)computes the workspace size from the statistics gathered so far and calls {@link #init()}.
     *
     * Steps, in order:
     * 1) a non-circular workspace with SpillPolicy.REALLOCATE that turned out to be too small is destroyed,
     *    so it can be recreated with a bigger size;
     * 2) a circular workspace in trimmed mode is destroyed once more than 2 steps have passed since trimming started;
     * 3) if the workspace is not initialized and a learning policy is set, the new size is derived from maxCycle
     *    (capped by max size, +30% for circular mode, plus optional overallocation, raised to min size),
     *    spilled allocations are purged when allowed, and init() is called.
     */
    @Override
    public void initializeWorkspace() {
        // we can reallocate this workspace to larger size if that's needed and allowed by configuration
        if ((currentSize.get() < maxCycle.get() || currentSize.get() < cycleAllocations.get())
                && workspaceConfiguration.getPolicySpill() == SpillPolicy.REALLOCATE
                && (workspaceConfiguration.getMaxSize() == 0
                || (maxCycle.get() < workspaceConfiguration.getMaxSize()))) {
            // circular workspaces are not reallocated here, they go through trimmed mode below
            if (workspaceConfiguration.getPolicyReset() != ResetPolicy.ENDOFBUFFER_REACHED) {
                destroyWorkspace(true);
                isInit.set(false);
            }
        }

        // if we're in cyclic mode, we do reallocations only after 2 full cycles, to avoid race conditions
        if (trimmedMode.get() && trimmedStep.get() + 2 < stepsCount.get()) {
            destroyWorkspace(false);
            isInit.set(false);
            // overallocation has to be applied again to the new size
            isOver.set(false);
        }

        if (!isInit.get())
            if (workspaceConfiguration.getPolicyLearning() != LearningPolicy.NONE) {
                // new size is the biggest cycle seen so far, capped by max size if there is one
                if (workspaceConfiguration.getMaxSize() > 0)
                    currentSize.set(Math.min(maxCycle.get(), workspaceConfiguration.getMaxSize()));
                else
                    currentSize.set(maxCycle.get());

                // if we're on cyclic mode, let's add 30% to size, just to reduce number of reallocations
                if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED) {
                    currentSize.set((long) (currentSize.get() * 1.3));
                    // pad to a multiple of 8; a size that is already a multiple of 8 still grows by 8
                    currentSize.addAndGet(8 - (currentSize.get() % 8));
                    maxCycle.set(currentSize.get());
                }

                // we're updating single block size for circular mode, will be used for alignment later
                initialBlockSize.set(currentSize.get());

                // handling optional overallocation here, however it's usually good idea to use it everywhere, to avoid frequent realloc calls
                if (!isOver.get()) {
                    if (workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE
                            && workspaceConfiguration.getOverallocationLimit() > 0 && currentSize.get() > 0) {
                        currentSize.set(currentSize.get()
                                + (long) (currentSize.get() * workspaceConfiguration.getOverallocationLimit()));
                        isOver.set(true);
                    }
                }

                if (workspaceConfiguration.getMinSize() > 0 && currentSize.get() < workspaceConfiguration.getMinSize())
                    currentSize.set(workspaceConfiguration.getMinSize());

                // purge spilled allocations
                if (externalCount.get() > 0 && (workspaceConfiguration.getPolicyReset() == ResetPolicy.BLOCK_LEFT
                        || resetPlanned.get())) {
                    clearExternalAllocations();
                    resetPlanned.set(false);
                }

                // calling for implementation-specific workspace initialization. basically allocation happens there
                init();
            }
    }

    /**
     * Returns the number of spilled allocations, which can be purged at the end of block.
     *
     * @return number of spilled (external) allocations
     */
    public int getNumberOfExternalAllocations() {
        return this.externalCount.get();
    }

    /**
     * Returns the number of pinned allocations; these can be purged after 2 steps.
     *
     * PLEASE NOTE: This method can return non-zero values only for circular workspace mode
     *
     * @return number of pinned allocations
     */
    public int getNumberOfPinnedAllocations() {
        return this.pinnedCount.get();
    }

    /**
     * Deallocates workspace memory, including external allocations.
     * Same as {@code destroyWorkspace(true)}.
     */
    @Override
    public void destroyWorkspace() {
        this.destroyWorkspace(true);
    }


    /**
     * Deallocates workspace memory: releases the host buffer (when it's backed by a BytePointer),
     * zeroes the size and offsets, and deregisters this workspace from the AllocationsTracker.
     *
     * @param extended if true, external (spilled) allocations are purged as well
     */
    @Override
    public void destroyWorkspace(boolean extended) {
        final PagedPointer hostPointer = workspace.getHostPointer();
        if (hostPointer != null) {
            // instanceof is false for null, so no separate null check is needed
            final Pointer original = hostPointer.getOriginalPointer();
            if (original instanceof BytePointer) {
                original.deallocate();
            }
        }

        workspace.setHostPointer(null);
        currentSize.set(0);
        reset();

        if (extended) {
            clearExternalAllocations();
        }

        //remove the workspace from tracking when done
        AllocationsTracker.getInstance().deregisterWorkspace(id);
    }

    /**
     * TEMPORARILY enters this workspace, without any reset applied.
     * The workspace that was current before the call is remembered and restored by the matching {@link #close()}.
     *
     * @return this workspace
     * @throws ND4JIllegalStateException if this workspace is already borrowed
     */
    @Override
    public MemoryWorkspace notifyScopeBorrowed() {
        if (this.isBorrowed.get()) {
            throw new ND4JIllegalStateException("Workspace [" + id + "]: Can't borrow from borrowed workspace");
        }

        // remember who we borrow from, then take over as current workspace
        this.borrowingWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        this.isBorrowed.set(true);
        Nd4j.getMemoryManager().setCurrentWorkspace(this);

        return this;
    }

    /**
     * Returns the number of cycles counted so far; the counter is incremented by {@link #close()}.
     *
     * @return cycles counter value
     */
    public long getCyclesCount() {
        return this.cyclesCount.get();
    }

    /**
     * Leaves the current scope of this workspace.
     *
     * Three cases are handled, in this order:
     * 1) borrowed workspace: the borrowing workspace is restored as current one, nothing is reset;
     * 2) nested/tagged scope (tagScope above zero): only the counter is decremented;
     * 3) regular scope end: pending ops are committed, the previous workspace is restored, cycle statistics are
     *    updated, the workspace is optionally (re)initialized, old pinned allocations are purged, and offsets are
     *    either reset (BLOCK_LEFT) or aligned to the block size (ENDOFBUFFER_REACHED).
     */
    @Override
    public void close() {
        // first we check if this workspace was borrowed. if yes - just close without reset.
        if (isBorrowed.get()) {
            if (tagScope.get() > 0) {
                if (tagScope.decrementAndGet() == 0) {
                    Nd4j.getMemoryManager().setCurrentWorkspace(this);
                }
                return;
            }

            isBorrowed.set(false);
            Nd4j.getMemoryManager().setCurrentWorkspace(borrowingWorkspace);
            return;
        }

        // next we check, if the same workspace was opened multiple times sequentially. then we just decrement counter, without reset
        if (tagScope.get() > 0) {
            if (tagScope.decrementAndGet() == 0) {
                Nd4j.getMemoryManager().setCurrentWorkspace(this);
            }
            return;
        }

        // this is for safety. We have to be sure that no ops were left non-processed
        //Furthermore, need to commit before marking workspace as closed, to avoid (incorrectly) hitting scope panic
        Nd4j.getExecutioner().commit();

        // since this workspace block is finished, we restore previous one. Even if it's null
        Nd4j.getMemoryManager().setCurrentWorkspace(previousWorkspace);
        isOpen.set(false);

        // just counter for cycles/blocks
        cyclesCount.incrementAndGet();
        if (cyclesCount.get() > 1 && (cyclesCount.get() - 1) % stepsNumber == 0) {
            // this counter is for cyclic mode, it counts generations, full loops over buffer
            stepsCount.incrementAndGet();
        }
        /*
            Basically all we want here, is:
            1) memset primary page(s)
            2) purge external allocations
         */

        if (!isUsed.get()) {
            log.warn("Workspace was turned off, and wasn't ever turned on back again");
            isUsed.set(true);
        }

        // if during this cycle we've used more memory then before - increase max count. we'll use it in future for optional reallocation
        if (cycleAllocations.get() > maxCycle.get()) {
            if (isDebug.get())
                log.info("Workspace [{}] device_{}, current cycle: {}; max cycle: {}", id,
                        Nd4j.getAffinityManager().getDeviceForCurrentThread(), cycleAllocations.get(),
                        maxCycle.get());

            maxCycle.set(cycleAllocations.get());
        }

        // checking, if we should reallocate this workspace to higher amount of memory
        if (workspaceConfiguration.getPolicyLearning() != LearningPolicy.NONE && maxCycle.get() > 0) {
            // if we're going to resize - we're probably safe to purge spilled allocations
            if (externalCount.get() > 0 && (workspaceConfiguration.getPolicyReset() == ResetPolicy.BLOCK_LEFT
                    || resetPlanned.get())) {
                clearExternalAllocations();
                resetPlanned.set(false);
            }

            // first initialization: either after the configured number of cycles (OVER_TIME),
            // or as soon as possible while the workspace is still empty (FIRST_LOOP)
            if ((workspaceConfiguration.getPolicyLearning() == LearningPolicy.OVER_TIME
                    && workspaceConfiguration.getCyclesBeforeInitialization() == cyclesCount.intValue())
                    || (workspaceConfiguration.getPolicyLearning() == LearningPolicy.FIRST_LOOP
                    && currentSize.get() == 0)) {
                if (Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING)
                    initializeWorkspace();
            } else if (currentSize.get() > 0 && cycleAllocations.get() > 0
                    && workspaceConfiguration.getPolicySpill() == SpillPolicy.REALLOCATE
                    && workspaceConfiguration.getPolicyReset() != ResetPolicy.ENDOFBUFFER_REACHED) {
                // already allocated, non-circular workspace with REALLOCATE: let initializeWorkspace() decide on resizing

                if (Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING)
                    initializeWorkspace();
            }
        }


        // clearing pinned allocations that are old enough
        if (pinnedCount.get() > 0)
            clearPinnedAllocations(false);

        // if we're in trimmed mode (preparing for reallocation of circular buffer) - we can do it 2 generations after
        if (trimmedMode.get() && trimmedStep.get() + 2 < stepsCount.get()) {
            initialBlockSize.set(maxCycle.get());
            initializeWorkspace();
            trimmedMode.set(false);
            trimmedStep.set(0);

            reset();
        }

        lastCycleAllocations.set(cycleAllocations.get());

        disabledCounter.set(0);


        if (workspaceConfiguration.getPolicyReset() == ResetPolicy.BLOCK_LEFT) {
            reset();
        } else if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED
                && currentSize.get() > 0) {

            // for variable input we want to ensure alignment to max block, to avoid accidental buffer overruns
            long diff = initialBlockSize.get() - cycleAllocations.get();

            // we don't care about offsets if that's trimmed mode, offsets will be reset anyway upon reallocation
            if (diff > 0 && !trimmedMode.get() && deviceOffset.get() > 0) {

                if (isDebug.get())
                    log.info("Worskpace [{}]: Align to [{}]; diff: [{}]; block size: [{}]; currentOffset: [{}]; workspaceSize: [{}]; trimmedMode: {}",
                            id, initialBlockSize.get(), diff, cycleAllocations.get(), deviceOffset.get(),
                            currentSize.get(), trimmedMode.get());

                // skip the unused tail of this block, so the next cycle starts at a block boundary
                deviceOffset.getAndAdd(diff);
                hostOffset.getAndAdd(diff);
            }
        }

        cycleAllocations.set(0);
    }

    /**
     * Backend-specific purge of pinned allocations. close() calls it with {@code false} whenever there are
     * pinned allocations, to drop those that are old enough.
     *
     * @param extended NOTE(review): presumably forces purging of all pinned allocations regardless of age - confirm against implementations
     */
    protected abstract void clearPinnedAllocations(boolean extended);

    /**
     * Backend-specific purge of external (spilled) allocations.
     * Called from destroyWorkspace(true), initializeWorkspace(), close() and notifyScopeEntered().
     */
    protected abstract void clearExternalAllocations();

    /**
     * Enters this workspace: makes it the current workspace of the memory manager and prepares it for a new cycle.
     *
     * Re-entering a workspace that is already open and current only bumps the nesting counter.
     *
     * @return this workspace
     */
    @Override
    public MemoryWorkspace notifyScopeEntered() {
        // we should block stuff since we're going to invalidate spilled allocations
        // TODO: block on spilled allocations probably?

        final MemoryWorkspace active = Nd4j.getMemoryManager().getCurrentWorkspace();

        // if we're opening the same workspace - just increase counter, and skip everything else
        if (active == this && isOpen.get()) {
            tagScope.incrementAndGet();
            return this;
        }

        // we'll need this in close() call, to restore previous workspace (if any)
        previousWorkspace = active;

        Nd4j.getMemoryManager().setCurrentWorkspace(this);
        isOpen.set(true);

        final boolean resetOnBlockLeft = workspaceConfiguration.getPolicyReset() == ResetPolicy.BLOCK_LEFT;

        // resetting workspace to 0 offset (if anything), not applicable to circular mode, sure
        if (resetOnBlockLeft) {
            reset();
        }

        // if we have any spilled allocations left from last cycle - purge them.
        if (externalCount.get() > 0 && (resetOnBlockLeft || resetPlanned.get())) {
            clearExternalAllocations();
            resetPlanned.set(false);
        }

        // fresh cycle: per-cycle counters start from zero, generation moves forward
        cycleAllocations.set(0);
        disabledCounter.set(0);
        generationId.incrementAndGet();

        return this;
    }

    /**
     * Resets both host and device offsets of this workspace to zero.
     *
     * PLEASE NOTE: Never call this method unless you realize all consequences
     */
    public void reset() {
        this.hostOffset.set(0);
        this.deviceOffset.set(0);
    }

    /**
     * Backend-specific hook; not called from within this class.
     * NOTE(review): the comments in close() mention memset of primary pages, presumably that's done here - confirm against implementations.
     */
    protected abstract void resetWorkspace();

    /**
     * Leaves the current scope; simply delegates to {@link #close()}.
     *
     * @return this workspace
     */
    @Override
    public MemoryWorkspace notifyScopeLeft() {
        this.close();
        return this;
    }

    /**
     * Temporarily disables or re-enables this workspace. While disabled, allocations are issued directly
     * via the memory manager instead of the workspace buffer.
     *
     * @param isEnabled true to use the workspace, false to bypass it
     */
    @Override
    public void toggleWorkspaceUse(boolean isEnabled) {
        this.isUsed.set(isEnabled);
    }

    /**
     * Returns the number of bytes allocated during the last full cycle.
     *
     * @return bytes allocated in the previous cycle
     */
    @Override
    public long getLastCycleAllocations() {
        return this.lastCycleAllocations.get();
    }

    /**
     * Returns the number of bytes allocated during THIS cycle.
     *
     * @return bytes allocated in the current cycle so far
     */
    @Override
    public long getThisCycleAllocations() {
        return this.cycleAllocations.get();
    }

    /**
     * Returns the number of bytes allocated in the biggest cycle seen so far.
     *
     * @return size of the biggest cycle, in bytes
     */
    @Override
    public long getMaxCycleAllocations() {
        return this.maxCycle.get();
    }

    /**
     * Tells whether a scope of this workspace was opened and not closed yet.
     *
     * @return true if the scope is currently open
     */
    @Override
    public boolean isScopeActive() {
        return this.isOpen.get();
    }

    /**
     * Increments the nesting counter of this workspace, so the next {@link #close()} call
     * only decrements that counter instead of leaving the scope.
     *
     * @return this workspace
     */
    @Override
    public MemoryWorkspace tagOutOfScopeUse() {
        this.tagScope.incrementAndGet();
        return this;
    }

    /**
     * Short textual description of this workspace: its id and current size.
     *
     * @return string representation of this workspace
     */
    @Override
    public String toString() {
        final StringBuilder description = new StringBuilder("Nd4jWorkspace{");
        description.append("id='").append(id).append('\'');
        description.append(", currentSize=").append(currentSize.get());
        description.append('}');
        return description.toString();
    }

}