deeplearning4j/deeplearning4j

View on GitHub
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/util/DeviceLocalNDArray.java

Summary

Maintainability
B
6 hrs
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.linalg.util;

import edu.umd.cs.findbugs.annotations.Nullable;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;

import java.util.Arrays;

/**
 * {@link DeviceLocal} specialization that keeps one {@link INDArray} per device.
 *
 * In immediate mode a broadcast stores a copy of the array for every device right away.
 * In delayed mode only the calling thread's device is updated; the other devices are marked
 * stale and build their own copy from a retained snapshot the next time {@link #get()} is
 * called from a thread attached to them.
 */
@Slf4j
public class DeviceLocalNDArray extends DeviceLocal<INDArray> {

    /**
     * Creates an empty instance in immediate (non-delayed) mode.
     */
    public DeviceLocalNDArray() {
        this(false);
    }

    /**
     * Creates an empty instance.
     *
     * @param delayedMode if true, broadcasts only update the calling device and mark the others stale
     */
    public DeviceLocalNDArray(boolean delayedMode) {
        super(delayedMode);
    }

    /**
     * Creates an instance in immediate mode and broadcasts the given array.
     *
     * @param array array to broadcast; a null array leaves the instance empty
     */
    public DeviceLocalNDArray(INDArray array) {
        this(array, false);
    }

    /**
     * Creates an instance and broadcasts the given array.
     *
     * @param array       array to broadcast; a null array leaves the instance empty
     * @param delayedMode if true, broadcasts only update the calling device and mark the others stale
     */
    public DeviceLocalNDArray(INDArray array, boolean delayedMode) {
        super(delayedMode);

        broadcast(array);
    }

    /**
     * This method returns object local to current deviceId.
     *
     * If the current device was marked stale by an update issued from another device, a fresh
     * copy is first built from the retained snapshot and stored for this device.
     *
     * @return array stored for the device of the calling thread (annotated as nullable)
     */
    @Nullable
    @Override
    public synchronized INDArray get() {
        val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        val sourceId = updatesMap.get(deviceId).get();
        if (sourceId >= 0 && sourceId != deviceId) {
            // if updates map contains some deviceId - we should take updated array from there
            val newArray = Nd4j.create(delayedArray.dataType(), delayedArray.shape(), delayedArray.stride(), delayedArray.ordering());
            Nd4j.getMemoryManager().memcpy(newArray.data(), delayedArray.data());
            backingMap.put(deviceId, newArray);

            // reset updates flag: a device pointing at itself is considered up to date
            updatesMap.get(deviceId).set(deviceId);

            // also check if all updates were consumed
            boolean allUpdated = true;
            for (int e = 0; e < numDevices; e++) {
                if (updatesMap.get(e).get() != e) {
                    allUpdated = false;
                    break;
                }
            }

            // once every device points at itself the snapshot is no longer needed
            if (allUpdated)
                delayedArray = null;
        }
        return get(deviceId);
    }

    /**
     * This method duplicates array, and stores it to all devices.
     *
     * The profiler's locality check is switched off for the duration of the call and is
     * restored to its previous value afterwards, even if the broadcast fails.
     *
     * PLEASE NOTE: this method is NOT atomic, so you must be sure no other threads are using this instance during the update
     *
     * @param array array to broadcast; null is silently ignored
     */
    public synchronized void broadcast(INDArray array) {
        if (array == null)
            return;

        // NOTE(review): as written, a view is accepted whenever its element-wise stride is not 1,
        // which looks at odds with the message - confirm the intended condition.
        Preconditions.checkArgument(!array.isView() || array.elementWiseStride() != 1, "View can't be used in DeviceLocalNDArray");

        Nd4j.getExecutioner().commit();

        val config = OpProfiler.getInstance().getConfig();
        val locality = config.isCheckLocality();

        if (locality)
            config.setCheckLocality(false);

        try {
            val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
            val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();

            if (!delayedMode) {
                // in immediate mode every device gets its array right away
                for (int i = 0; i < numDevices; i++) {
                    // if current thread equal to this device - we just save it, without duplication
                    if (deviceId == i) {
                        set(i, array.detach());
                    } else {
                        set(i, Nd4j.getAffinityManager().replicateToDevice(i, array));
                    }
                }
            } else {
                // we're only updating this device
                set(deviceId, array);
                delayedArray = array.dup(array.ordering()).detach();

                // and marking all other devices as stale, and provide id of device with the most recent array
                for (int i = 0; i < numDevices; i++) {
                    if (i != deviceId) {
                        updatesMap.get(i).set(deviceId);
                    }
                }
            }
        } finally {
            // always restore the caller's setting, otherwise a failed broadcast would leave
            // locality checking disabled
            config.setCheckLocality(locality);
        }
    }

    /**
     * This method updates the stored arrays with the content of the given one.
     *
     * When the current device already holds an array with identical shape information, the
     * new content is copied into every existing per-device array; devices without an array
     * are marked stale and are served from a retained snapshot later. In every other case
     * (no array for the current device yet, or different shape information) the call is
     * delegated to {@link #broadcast(INDArray)}.
     *
     * PLEASE NOTE: this method is NOT atomic, so you must be sure no other threads are using this instance during the update
     *
     * @param array array holding the new content, must not be null
     */
    public synchronized void update(@NonNull INDArray array) {
        // NOTE(review): same condition as in broadcast() - confirm it matches the message.
        Preconditions.checkArgument(!array.isView() || array.elementWiseStride() != 1, "View can't be used in DeviceLocalNDArray");

        val numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        val device = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        val currentArray = backingMap.get(device);

        // nothing stored for this device yet, or the layout differs: an in-place copy is
        // impossible, so issue a broadcast call instead
        if (currentArray == null || !Arrays.equals(currentArray.shapeInfoJava(), array.shapeInfoJava())) {
            broadcast(array);
            return;
        }

        boolean wasDelayed = false;

        // arrays are the same - we'll just issue memcpy
        for (int k = 0; k < numDevices; k++) {
            val lock = locksMap.get(k);

            // acquire before entering try, so finally never unlocks a lock that was not taken
            lock.writeLock().lock();
            try {
                val v = backingMap.get(k);
                if (v == null) {
                    // take the snapshot only once, no matter how many devices are missing
                    if (!wasDelayed) {
                        delayedArray = array.dup(array.ordering()).detach();
                        wasDelayed = true;
                    }
                    updatesMap.get(k).set(device);
                    continue;
                }

                Nd4j.getMemoryManager().memcpy(v.data(), array.data());
                Nd4j.getExecutioner().commit();
            } finally {
                lock.writeLock().unlock();
            }
        }
    }
}